From 328d56137e86dcafd2568b80fd24b60a8f942e87 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 9 Jan 2026 16:26:12 +0800 Subject: [PATCH 01/22] init --- cookbook/megatron/__init__.py | 3 + cookbook/megatron/client.py | 287 ++++ cookbook/megatron/lora.py | 217 +++ cookbook/megatron/server.py | 270 ++++ cookbook/sft/lora.py | 2 +- src/twinkle/infra/__init__.py | 2 +- src/twinkle/infra/ray/resource_manager.py | 6 +- src/twinkle/loss/__init__.py | 3 + .../loss/vocab_parallel_cross_entropy.py | 87 ++ src/twinkle/megatron/__init__.py | 107 ++ src/twinkle/megatron/model/__init__.py | 52 + src/twinkle/megatron/model/bridge.py | 1298 +++++++++++++++++ src/twinkle/megatron/model/initializer.py | 325 +++++ src/twinkle/megatron/model/qwen3.py | 64 + src/twinkle/megatron/tuners/__init__.py | 9 + src/twinkle/megatron/tuners/lora.py | 606 ++++++++ src/twinkle/megatron/utils.py | 1034 +++++++++++++ src/twinkle/model/__init__.py | 1 + src/twinkle/model/megatron.py | 856 +++++++++++ src/twinkle/model/strategy/__init__.py | 1 + src/twinkle/model/strategy/megatron.py | 638 ++++++++ src/twinkle/model/transformers.py | 6 +- src/twinkle/processor/base.py | 8 +- src/twinkle/utils/framework.py | 19 + twinkle | 1 - 25 files changed, 5895 insertions(+), 7 deletions(-) create mode 100644 cookbook/megatron/__init__.py create mode 100644 cookbook/megatron/client.py create mode 100644 cookbook/megatron/lora.py create mode 100644 cookbook/megatron/server.py create mode 100644 src/twinkle/loss/vocab_parallel_cross_entropy.py create mode 100644 src/twinkle/megatron/__init__.py create mode 100644 src/twinkle/megatron/model/__init__.py create mode 100644 src/twinkle/megatron/model/bridge.py create mode 100644 src/twinkle/megatron/model/initializer.py create mode 100644 src/twinkle/megatron/model/qwen3.py create mode 100644 src/twinkle/megatron/tuners/__init__.py create mode 100644 src/twinkle/megatron/tuners/lora.py create mode 100644 src/twinkle/megatron/utils.py create mode 100644 
src/twinkle/model/megatron.py create mode 100644 src/twinkle/model/strategy/megatron.py delete mode 120000 twinkle diff --git a/cookbook/megatron/__init__.py b/cookbook/megatron/__init__.py new file mode 100644 index 00000000..1da7257c --- /dev/null +++ b/cookbook/megatron/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron training examples for twinkle.""" + diff --git a/cookbook/megatron/client.py b/cookbook/megatron/client.py new file mode 100644 index 00000000..862d11e3 --- /dev/null +++ b/cookbook/megatron/client.py @@ -0,0 +1,287 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron LoRA training client. + +This client sends training requests to the Megatron model server. + +Usage: + # First start the server: + python cookbook/megatron/server.py --port 8000 + + # Then run the client: + python cookbook/megatron/client.py --server_url http://localhost:8000 +""" +import argparse +from typing import Any, Dict, Optional + +import requests + +from twinkle import get_logger +from twinkle.dataset import Dataset, DatasetMeta + +logger = get_logger() + + +class MegatronModelClient: + """Client for remote Megatron model training.""" + + def __init__(self, server_url: str, timeout: int = 300): + """Initialize client. + + Args: + server_url: URL of the model server. + timeout: Request timeout in seconds. + """ + self.server_url = server_url.rstrip('/') + self.timeout = timeout + + def _request(self, endpoint: str, method: str = 'POST', data: Dict = None) -> Dict: + """Send request to server. + + Args: + endpoint: API endpoint. + method: HTTP method. + data: Request data. + + Returns: + Response data. 
+ """ + url = f'{self.server_url}/{endpoint}' + + try: + if method == 'GET': + response = requests.get(url, timeout=self.timeout) + else: + response = requests.post(url, json=data or {}, timeout=self.timeout) + + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f'Request failed: {e}') + return {'status': 'error', 'message': str(e)} + + def health_check(self) -> bool: + """Check if server is healthy. + + Returns: + True if server is healthy. + """ + result = self._request('health', method='GET') + return result.get('status') == 'healthy' + + def initialize_model( + self, + model_name: str, + lora_config: Optional[Dict[str, Any]] = None, + ) -> Dict: + """Initialize model on server. + + Args: + model_name: HuggingFace model name or path. + lora_config: Optional LoRA configuration. + + Returns: + Server response. + """ + return self._request('initialize', data={ + 'model_name': model_name, + 'lora_config': lora_config, + }) + + def set_optimizer(self, optimizer_type: str = 'AdamW', **kwargs) -> Dict: + """Set optimizer on server. + + Args: + optimizer_type: Optimizer type name. + **kwargs: Optimizer arguments. + + Returns: + Server response. + """ + return self._request('set_optimizer', data={ + 'optimizer_type': optimizer_type, + **kwargs, + }) + + def set_lr_scheduler(self, scheduler_type: str = 'CosineAnnealingLR', **kwargs) -> Dict: + """Set learning rate scheduler on server. + + Args: + scheduler_type: Scheduler type name. + **kwargs: Scheduler arguments. + + Returns: + Server response. + """ + return self._request('set_lr_scheduler', data={ + 'scheduler_type': scheduler_type, + **kwargs, + }) + + def train_step(self, batch: Dict[str, Any]) -> Dict: + """Execute one training step. + + Args: + batch: Input batch data. + + Returns: + Server response with loss. 
+ """ + return self._request('train_step', data={'batch': batch}) + + def save_checkpoint(self, output_path: str) -> Dict: + """Save model checkpoint. + + Args: + output_path: Path to save checkpoint. + + Returns: + Server response. + """ + return self._request('save', data={'output_path': output_path}) + + def get_train_configs(self) -> Dict: + """Get training configuration from server. + + Returns: + Training configuration. + """ + return self._request('configs', method='GET') + + +def create_dataset(args): + """Create and preprocess dataset.""" + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset)) + dataset.set_template('Qwen3Template', model_id=args.model_name) + dataset.map('CompetitionMathProcessor') + dataset.encode(batched=True) + return dataset + + +def parse_args(): + parser = argparse.ArgumentParser(description='Megatron Model Client') + + # Server arguments + parser.add_argument('--server_url', type=str, default='http://localhost:8000', + help='Model server URL') + parser.add_argument('--timeout', type=int, default=300, + help='Request timeout in seconds') + + # Model arguments + parser.add_argument('--model_name', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct', + help='HuggingFace model name or path') + parser.add_argument('--output_dir', type=str, default='./output/megatron_lora', + help='Output directory for checkpoints') + + # LoRA arguments + parser.add_argument('--lora_rank', type=int, default=8, + help='LoRA rank') + parser.add_argument('--lora_alpha', type=int, default=32, + help='LoRA alpha') + parser.add_argument('--lora_dropout', type=float, default=0.05, + help='LoRA dropout') + parser.add_argument('--target_modules', type=str, default='all-linear', + help='Target modules for LoRA') + + # Training arguments + parser.add_argument('--batch_size', type=int, default=4, + help='Batch size') + parser.add_argument('--learning_rate', type=float, default=1e-4, + help='Learning rate') + parser.add_argument('--max_steps', type=int, 
default=1000, + help='Maximum training steps') + parser.add_argument('--save_steps', type=int, default=50, + help='Checkpoint save interval') + parser.add_argument('--log_steps', type=int, default=10, + help='Logging interval') + + # Dataset arguments + parser.add_argument('--dataset', type=str, default='ms://modelscope/competition_math', + help='Dataset name') + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Create client + client = MegatronModelClient( + server_url=args.server_url, + timeout=args.timeout, + ) + + # Health check + if not client.health_check(): + logger.error('Server is not available') + return + + logger.info('Server is healthy, initializing model...') + + # Initialize model with LoRA + lora_config = { + 'r': args.lora_rank, + 'lora_alpha': args.lora_alpha, + 'lora_dropout': args.lora_dropout, + 'target_modules': args.target_modules, + } + + result = client.initialize_model( + model_name=args.model_name, + lora_config=lora_config, + ) + + if result.get('status') != 'success': + logger.error(f'Failed to initialize model: {result}') + return + + logger.info('Model initialized, setting optimizer...') + + # Set optimizer and scheduler + client.set_optimizer(optimizer_type='AdamW', lr=args.learning_rate, weight_decay=0.01) + client.set_lr_scheduler(scheduler_type='CosineAnnealingLR', T_max=args.max_steps) + + # Print training configuration + configs = client.get_train_configs() + logger.info(f'Training configs: {configs}') + + # Create dataset and dataloader + logger.info('Loading dataset...') + dataset = create_dataset(args) + + # Training loop + logger.info('Starting training...') + global_step = 0 + + for step, batch in enumerate(dataset.iter(batch_size=args.batch_size)): + if global_step >= args.max_steps: + break + + # Send batch to server for training + result = client.train_step(batch) + + if result.get('status') != 'success': + logger.error(f'Training step failed: {result}') + continue + + global_step += 1 + + # Log 
progress + if global_step % args.log_steps == 0: + loss = result.get('loss', 'N/A') + logger.info(f'Step {global_step}, Loss: {loss}') + + # Save checkpoint + if global_step % args.save_steps == 0: + save_result = client.save_checkpoint( + f'{args.output_dir}/checkpoint-{global_step}' + ) + logger.info(f'Checkpoint saved: {save_result}') + + # Save final model + client.save_checkpoint(args.output_dir) + logger.info(f'Training completed. Model saved to {args.output_dir}') + + +if __name__ == '__main__': + main() + diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py new file mode 100644 index 00000000..589b1448 --- /dev/null +++ b/cookbook/megatron/lora.py @@ -0,0 +1,217 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core LoRA training example. + +This example demonstrates LoRA fine-tuning using Megatron-Core backend. +Supports both local (DDP) and Ray distributed modes. + +Usage (local mode with 4 GPUs): + torchrun --nproc_per_node=4 cookbook/megatron/lora.py --mode local + +Usage (Ray mode): + python cookbook/megatron/lora.py --mode ray +""" +import argparse + +import numpy as np +from peft import LoraConfig +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR + +import twinkle +from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import CrossEntropyLoss, MegatronCrossEntropyLoss +from twinkle.model import MegatronModel +from twinkle.processor import InputProcessor + +logger = get_logger() + + +def parse_args(): + parser = argparse.ArgumentParser(description='Megatron LoRA Training') + + # Mode selection + parser.add_argument('--mode', type=str, default='ray', + choices=['local', 'ray'], + help='Distributed mode: local (DDP) or ray') + + # Model arguments + parser.add_argument('--model_name', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct', + 
help='HuggingFace model name or path') + parser.add_argument('--output_dir', type=str, default='./output/megatron_lora', + help='Output directory for checkpoints') + + # Parallelism arguments + parser.add_argument('--nproc_per_node', type=int, default=4, + help='Number of processes per node') + parser.add_argument('--tp_size', type=int, default=2, + help='Tensor parallel size') + parser.add_argument('--dp_size', type=int, default=2, + help='Data parallel size') + parser.add_argument('--sequence_parallel', action='store_true', + help='Enable sequence parallelism') + parser.add_argument('--mixed_precision', type=str, default='bf16', + choices=['no', 'fp16', 'bf16'], + help='Mixed precision mode') + + # LoRA arguments + parser.add_argument('--lora_rank', type=int, default=8, + help='LoRA rank') + parser.add_argument('--lora_alpha', type=int, default=32, + help='LoRA alpha') + parser.add_argument('--lora_dropout', type=float, default=0.05, + help='LoRA dropout') + parser.add_argument('--target_modules', type=str, default='all-linear', + help='Target modules for LoRA') + + # Training arguments + parser.add_argument('--batch_size', type=int, default=4, + help='Batch size per GPU') + parser.add_argument('--gradient_accumulation_steps', type=int, default=16, + help='Gradient accumulation steps') + parser.add_argument('--learning_rate', type=float, default=1e-4, + help='Learning rate') + parser.add_argument('--max_grad_norm', type=float, default=1.0, + help='Maximum gradient norm for clipping') + parser.add_argument('--max_steps', type=int, default=1000, + help='Maximum training steps') + parser.add_argument('--save_steps', type=int, default=50, + help='Checkpoint save interval') + + # Dataset arguments + parser.add_argument('--dataset', type=str, default='ms://modelscope/competition_math', + help='Dataset name') + + return parser.parse_args() + + +def create_device_mesh(args) -> DeviceMesh: + """Create device mesh for Megatron parallelism.""" + # For Megatron: mesh shape 
is (dp, tp) + # dp_size * tp_size = nproc_per_node + mesh = np.arange(args.nproc_per_node).reshape(args.dp_size, args.tp_size) + + device_mesh = DeviceMesh( + device_type='cuda', + mesh=mesh, + mesh_dim_names=('dp', 'tp'), + ) + return device_mesh + + +def create_device_group(args): + """Create device group for model placement.""" + device_group = [ + DeviceGroup( + name='model', + ranks=list(range(args.nproc_per_node)), + device_type=Platform.get_platform().device_prefix(), + ) + ] + return device_group + + +def create_dataset(args): + """Create and preprocess dataset.""" + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset)) + dataset.set_template('Qwen3Template', model_id=args.model_name) + dataset.map('CompetitionMathProcessor') + dataset.encode(batched=True) + return dataset + + +def train(args): + """Main training function.""" + # Create dataloader + dataloader = DataLoader( + dataset=lambda: create_dataset(args), + batch_size=args.batch_size, + ) + + # Create Megatron model + model = MegatronModel( + pretrained_model_name_or_path=args.model_name, + tensor_model_parallel_size=args.tp_size, + sequence_parallel=args.sequence_parallel, + mixed_precision=args.mixed_precision, + ) + + # Configure LoRA adapter + lora_config = LoraConfig( + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.target_modules, + ) + + model.add_adapter_to_model( + 'default', + lora_config, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + + # Set template and processor + model.set_template('Qwen3Template') + model.set_processor(InputProcessor, padding_side='right') + + # Set loss, optimizer, scheduler + model.set_loss(MegatronCrossEntropyLoss) + model.set_optimizer(AdamW, lr=args.learning_rate, weight_decay=0.01) + model.set_lr_scheduler(CosineAnnealingLR, T_max=args.max_steps) + + # Print training configuration + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + + # Training loop 
+ global_step = 0 + for step, batch in enumerate(dataloader): + if global_step >= args.max_steps: + break + + # Forward-backward pass + output = model.forward_backward(inputs=batch) + + # Log loss at gradient accumulation boundary + if step % args.gradient_accumulation_steps == 0: + logger.info(f'Step {global_step}, Loss: {output}') + global_step += 1 + + # Gradient clipping and optimizer step + model.clip_grad_norm(args.max_grad_norm) + model.step() + model.zero_grad() + model.lr_step() + + # Save checkpoint + if global_step > 0 and global_step % args.save_steps == 0: + model.save(f'{args.output_dir}/checkpoint-{global_step}') + + # Save final model + model.save(args.output_dir) + logger.info(f'Model saved to {args.output_dir}') + + +def main(): + args = parse_args() + + # Create device mesh and group + device_mesh = create_device_mesh(args) + device_group = create_device_group(args) + + # Initialize twinkle with specified mode + twinkle.initialize( + mode=args.mode, + nproc_per_node=args.nproc_per_node, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, + ) + + # Start training + train(args) + + +if __name__ == '__main__': + main() diff --git a/cookbook/megatron/server.py b/cookbook/megatron/server.py new file mode 100644 index 00000000..71255028 --- /dev/null +++ b/cookbook/megatron/server.py @@ -0,0 +1,270 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron LoRA training server. + +This server hosts the Megatron model and handles training requests from clients. 
+ +Usage: + python cookbook/megatron/server.py --port 8000 --tp_size 2 +""" +import argparse +from typing import Any, Dict + +import numpy as np + +import twinkle +from twinkle import get_logger, DeviceMesh, DeviceGroup, Platform +from twinkle.model import MegatronModel +from twinkle.loss import CrossEntropyLoss +from twinkle.processor import InputProcessor + +logger = get_logger() + + +class MegatronModelServer: + """Server wrapper for Megatron model.""" + + def __init__(self, args): + self.args = args + self.model = None + self.is_initialized = False + + def initialize_model(self, model_name: str, lora_config: Dict[str, Any] = None): + """Initialize the Megatron model with optional LoRA configuration. + + Args: + model_name: HuggingFace model name or path. + lora_config: Optional LoRA configuration dict. + """ + logger.info(f'Initializing model: {model_name}') + + self.model = MegatronModel( + pretrained_model_name_or_path=model_name, + tensor_model_parallel_size=self.args.tp_size, + sequence_parallel=self.args.sequence_parallel, + mixed_precision=self.args.mixed_precision, + ) + + if lora_config: + from peft import LoraConfig + config = LoraConfig(**lora_config) + self.model.add_adapter_to_model( + 'default', + config, + gradient_accumulation_steps=self.args.gradient_accumulation_steps, + ) + + self.model.set_template('Qwen3Template') + self.model.set_processor(InputProcessor, padding_side='right') + self.model.set_loss(CrossEntropyLoss) + + self.is_initialized = True + logger.info('Model initialized successfully') + + return {'status': 'success', 'message': 'Model initialized'} + + def set_optimizer(self, optimizer_type: str = 'AdamW', **kwargs): + """Set optimizer for the model.""" + if not self.is_initialized: + return {'status': 'error', 'message': 'Model not initialized'} + + from torch.optim import AdamW, SGD + optimizer_map = {'AdamW': AdamW, 'SGD': SGD} + + if optimizer_type not in optimizer_map: + return {'status': 'error', 'message': f'Unknown 
optimizer: {optimizer_type}'} + + self.model.set_optimizer(optimizer_map[optimizer_type], **kwargs) + return {'status': 'success', 'message': f'Optimizer {optimizer_type} set'} + + def set_lr_scheduler(self, scheduler_type: str = 'CosineAnnealingLR', **kwargs): + """Set learning rate scheduler.""" + if not self.is_initialized: + return {'status': 'error', 'message': 'Model not initialized'} + + from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, StepLR + scheduler_map = { + 'CosineAnnealingLR': CosineAnnealingLR, + 'LinearLR': LinearLR, + 'StepLR': StepLR, + } + + if scheduler_type not in scheduler_map: + return {'status': 'error', 'message': f'Unknown scheduler: {scheduler_type}'} + + self.model.set_lr_scheduler(scheduler_map[scheduler_type], **kwargs) + return {'status': 'success', 'message': f'Scheduler {scheduler_type} set'} + + def train_step(self, batch: Dict[str, Any]): + """Execute one training step. + + Args: + batch: Input batch data. + + Returns: + Training step result with loss. + """ + if not self.is_initialized: + return {'status': 'error', 'message': 'Model not initialized'} + + # Forward-backward pass + loss = self.model.forward_backward(inputs=batch) + + # Optimizer step + self.model.clip_grad_norm(self.args.max_grad_norm) + self.model.step() + self.model.zero_grad() + self.model.lr_step() + + return {'status': 'success', 'loss': float(loss) if loss else None} + + def save_checkpoint(self, output_path: str): + """Save model checkpoint. + + Args: + output_path: Path to save checkpoint. 
+ """ + if not self.is_initialized: + return {'status': 'error', 'message': 'Model not initialized'} + + self.model.save(output_path) + return {'status': 'success', 'message': f'Checkpoint saved to {output_path}'} + + def get_train_configs(self): + """Get current training configuration.""" + if not self.is_initialized: + return {'status': 'error', 'message': 'Model not initialized'} + + return {'status': 'success', 'configs': self.model.get_train_configs()} + + +def create_device_mesh(args) -> DeviceMesh: + """Create device mesh for Megatron parallelism.""" + mesh = np.arange(args.nproc_per_node).reshape(args.dp_size, args.tp_size) + + device_mesh = DeviceMesh( + device_type='cuda', + mesh=mesh, + mesh_dim_names=('dp', 'tp'), + ) + return device_mesh + + +def create_device_group(args): + """Create device group for model placement.""" + device_group = [ + DeviceGroup( + name='model', + ranks=list(range(args.nproc_per_node)), + device_type=Platform.get_platform().device_prefix(), + ) + ] + return device_group + + +def parse_args(): + parser = argparse.ArgumentParser(description='Megatron Model Server') + + # Server arguments + parser.add_argument('--host', type=str, default='0.0.0.0', + help='Server host') + parser.add_argument('--port', type=int, default=8000, + help='Server port') + + # Parallelism arguments + parser.add_argument('--nproc_per_node', type=int, default=4, + help='Number of processes per node') + parser.add_argument('--tp_size', type=int, default=2, + help='Tensor parallel size') + parser.add_argument('--dp_size', type=int, default=2, + help='Data parallel size') + parser.add_argument('--sequence_parallel', action='store_true', + help='Enable sequence parallelism') + parser.add_argument('--mixed_precision', type=str, default='bf16', + choices=['no', 'fp16', 'bf16'], + help='Mixed precision mode') + + # Training defaults + parser.add_argument('--gradient_accumulation_steps', type=int, default=16, + help='Gradient accumulation steps') + 
parser.add_argument('--max_grad_norm', type=float, default=1.0, + help='Maximum gradient norm for clipping') + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Initialize distributed environment + device_mesh = create_device_mesh(args) + device_group = create_device_group(args) + + twinkle.initialize( + mode='local', + nproc_per_node=args.nproc_per_node, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, + ) + + # Create model server + server = MegatronModelServer(args) + + # Start HTTP server + try: + from flask import Flask, request, jsonify + except ImportError: + logger.error('Flask not installed. Install with: pip install flask') + return + + app = Flask(__name__) + + @app.route('/health', methods=['GET']) + def health(): + return jsonify({'status': 'healthy'}) + + @app.route('/initialize', methods=['POST']) + def initialize(): + data = request.json + result = server.initialize_model( + model_name=data.get('model_name'), + lora_config=data.get('lora_config'), + ) + return jsonify(result) + + @app.route('/set_optimizer', methods=['POST']) + def set_optimizer(): + data = request.json + result = server.set_optimizer(**data) + return jsonify(result) + + @app.route('/set_lr_scheduler', methods=['POST']) + def set_lr_scheduler(): + data = request.json + result = server.set_lr_scheduler(**data) + return jsonify(result) + + @app.route('/train_step', methods=['POST']) + def train_step(): + data = request.json + result = server.train_step(batch=data.get('batch', {})) + return jsonify(result) + + @app.route('/save', methods=['POST']) + def save(): + data = request.json + result = server.save_checkpoint(output_path=data.get('output_path')) + return jsonify(result) + + @app.route('/configs', methods=['GET']) + def configs(): + result = server.get_train_configs() + return jsonify(result) + + logger.info(f'Starting server on {args.host}:{args.port}') + app.run(host=args.host, port=args.port, threaded=False) + + +if 
__name__ == '__main__': + main() + diff --git a/cookbook/sft/lora.py b/cookbook/sft/lora.py index e5afffce..c8803fc2 100644 --- a/cookbook/sft/lora.py +++ b/cookbook/sft/lora.py @@ -35,7 +35,7 @@ # mesh_dim_names=('dp',) #) -twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh, lazy_collect=False) +twinkle.initialize(mode='ray', nproc_per_node=4, groups=device_group, global_device_mesh=device_mesh, lazy_collect=False) def create_dataset(): diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 19a7ac46..ae368f7c 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -217,7 +217,7 @@ def render_mesh_grid(mesh_array, dim_names): lines.extend(section_bottom()) lines.append("") - return "\n".join(lines) + return "\n" + "\n".join(lines) def _get_workers(workers, execute): diff --git a/src/twinkle/infra/ray/resource_manager.py b/src/twinkle/infra/ray/resource_manager.py index 3e38ede4..13609670 100644 --- a/src/twinkle/infra/ray/resource_manager.py +++ b/src/twinkle/infra/ray/resource_manager.py @@ -96,12 +96,16 @@ def __init__(self, self.device_groups = {} ray_address = str(ray.get_runtime_context().gcs_address) + min_rank = min(all_ranks) if all_ranks else 0 for group in groups: if device_type != 'CPU': ranks = group.ranks local_device_groups = [] for rank in ranks: - node_rank = rank // nproc_per_node + # Normalize rank by subtracting min_rank for node calculation + normalized_rank = rank - min_rank + node_rank = normalized_rank // nproc_per_node + # Use original rank for gpu_rank to support non-zero starting ranks gpu_rank = rank % nproc_per_node local_device_groups.append( dict( diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index f8eb226c..e223a6ed 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -10,6 +10,7 @@ from .listwise_reranker import ListwiseRerankerLoss from .listwise_generative_reranker import 
ListwiseGenerativeRerankerLoss from .grpo import GRPOLoss +from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss, MegatronCrossEntropyLoss from .base import Loss torch_loss_mapping = { @@ -25,4 +26,6 @@ 'listwise_reranker': ListwiseRerankerLoss, 'listwise_generative_reranker': ListwiseGenerativeRerankerLoss, 'grpo': GRPOLoss, + 'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss, + 'megatron_cross_entropy': MegatronCrossEntropyLoss, } \ No newline at end of file diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py new file mode 100644 index 00000000..16dc03aa --- /dev/null +++ b/src/twinkle/loss/vocab_parallel_cross_entropy.py @@ -0,0 +1,87 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Vocabulary-parallel cross entropy loss for Megatron TP training. + +When using Tensor Parallelism, the vocabulary dimension is sharded across TP ranks. +Standard CrossEntropyLoss will fail because labels may fall outside the local +vocab partition. This module provides vocab-parallel loss computation. +""" +from typing import Optional + +import torch +import torch.nn.functional as F + +from .base import Loss + +try: + from megatron.core import parallel_state as mpu + from megatron.core import tensor_parallel + MEGATRON_AVAILABLE = True +except ImportError: + MEGATRON_AVAILABLE = False + + +class VocabParallelCrossEntropyLoss(Loss): + """Cross entropy loss that handles vocabulary parallelism in Megatron. + + When using TP (Tensor Parallelism), the vocabulary is sharded across TP ranks. + This loss uses Megatron's vocab_parallel_cross_entropy which correctly handles + the distributed computation. + + Fallback: When Megatron is not available or TP=1, uses standard CrossEntropyLoss. 
+ """ + + def __call__(self, inputs, outputs, **kwargs): + logits = outputs['logits'] + labels = inputs['labels'] + + # Get dimensions + # logits: [batch, seq, vocab] or [batch, seq, partition_vocab] + # labels: [batch, seq] + + if not MEGATRON_AVAILABLE: + # Fallback to standard loss + logits_2d = logits.view(-1, logits.shape[-1]) + labels_1d = labels.view(-1) + return F.cross_entropy(logits_2d, labels_1d, ignore_index=-100) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + if tp_size == 1: + # No TP, use standard cross entropy + logits_2d = logits.view(-1, logits.shape[-1]) + labels_1d = labels.view(-1) + return F.cross_entropy(logits_2d, labels_1d, ignore_index=-100) + + # Use Megatron's vocab-parallel cross entropy + # Megatron expects [seq, batch, vocab] format for logits + # and [seq, batch] for labels + + # Transpose logits: [batch, seq, vocab] -> [seq, batch, vocab] + logits_sbv = logits.transpose(0, 1).contiguous() + + # Transpose labels: [batch, seq] -> [seq, batch] + # Must be contiguous for Megatron's view() operations + labels_sb = labels.transpose(0, 1).contiguous() + + # Megatron's vocab_parallel_cross_entropy handles the TP sharding correctly + # It returns per-token loss of shape [seq, batch] + per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb) + + # Transpose back: [seq, batch] -> [batch, seq] + per_token_loss = per_token_loss.transpose(0, 1).contiguous() + + # Apply loss mask (ignore labels == -100) + loss_mask = (labels != -100).float() + + # Compute mean loss (only over non-masked positions) + loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + return loss + + +class MegatronCrossEntropyLoss(VocabParallelCrossEntropyLoss): + """Alias for VocabParallelCrossEntropyLoss. + + Use this when training with Megatron backend and TP > 1. 
+ """ + pass diff --git a/src/twinkle/megatron/__init__.py b/src/twinkle/megatron/__init__.py new file mode 100644 index 00000000..b5ccdf62 --- /dev/null +++ b/src/twinkle/megatron/__init__.py @@ -0,0 +1,107 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core integration for twinkle training framework. + +This module provides independent implementation for Megatron support, +without external dependencies on swift's GPTBridge. +""" + +from .tuners import LoraParallelLinear, dispatch_megatron +from .utils import ( + # Layer finding + find_all_linears, + find_router, + find_embedding, + get_target_modules, + set_linear_is_expert, + # Model preparation + prepare_mcore_model, + prepare_lora_model, + # Config conversion + convert_hf_config, + # Utilities + get_model_parameter_info, + get_padding_to, + patch_deepcopy, + tuners_sharded_state_dict, + forward_step_helper, + deep_getattr, + # Multi-tenant support + TenantProcessGroupManager, + get_tenant_manager, + # Training state + MegatronTrainerState, +) +from .model import ( + # Bridge classes + TwinkleBridgeAdapter, + TwinkleGPTBridge, + BridgeConfig, + SafetensorLoader, + StreamingSafetensorSaver, + LazyTensor, + # Helper functions + load_hf_weights_to_megatron, + is_last_rank, + deep_getattr as bridge_deep_getattr, # Avoid conflict with utils.deep_getattr + # Legacy compatibility + create_megatron_args, + set_megatron_args, + restore_megatron_args, + mock_megatron_args, + # Initializer + MegatronModelInitializer, + initialize_megatron_model, + # Qwen3 support + Qwen3ModelMeta, + get_model_default_config, +) + +__all__ = [ + # Tuners + 'LoraParallelLinear', + 'dispatch_megatron', + # Layer finding + 'find_all_linears', + 'find_router', + 'find_embedding', + 'get_target_modules', + 'set_linear_is_expert', + # Model preparation + 'prepare_mcore_model', + 'prepare_lora_model', + # Config conversion + 'convert_hf_config', + # Utilities + 'get_model_parameter_info', + 'get_padding_to', + 
'patch_deepcopy', + 'tuners_sharded_state_dict', + 'forward_step_helper', + 'deep_getattr', + # Multi-tenant support + 'TenantProcessGroupManager', + 'get_tenant_manager', + # Training state + 'MegatronTrainerState', + # Bridge classes + 'TwinkleBridgeAdapter', + 'TwinkleGPTBridge', + 'BridgeConfig', + 'SafetensorLoader', + 'StreamingSafetensorSaver', + 'LazyTensor', + # Helper functions + 'load_hf_weights_to_megatron', + 'is_last_rank', + # Legacy compatibility + 'create_megatron_args', + 'set_megatron_args', + 'restore_megatron_args', + 'mock_megatron_args', + # Initializer + 'MegatronModelInitializer', + 'initialize_megatron_model', + # Qwen3 support + 'Qwen3ModelMeta', + 'get_model_default_config', +] diff --git a/src/twinkle/megatron/model/__init__.py b/src/twinkle/megatron/model/__init__.py new file mode 100644 index 00000000..d421c70b --- /dev/null +++ b/src/twinkle/megatron/model/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron model initialization and weight conversion. + +This module provides independent implementation for weight loading/saving, +without external dependencies on swift. 
+""" + +from .bridge import ( + # Main classes + TwinkleBridgeAdapter, + TwinkleGPTBridge, + BridgeConfig, + SafetensorLoader, + StreamingSafetensorSaver, + LazyTensor, + # Helper functions + deep_getattr, + is_last_rank, + load_hf_weights_to_megatron, + # Legacy compatibility + create_megatron_args, + set_megatron_args, + restore_megatron_args, + mock_megatron_args, +) +from .initializer import MegatronModelInitializer, initialize_megatron_model +from .qwen3 import Qwen3ModelMeta, get_model_default_config + +__all__ = [ + # Bridge classes + 'TwinkleBridgeAdapter', + 'TwinkleGPTBridge', + 'BridgeConfig', + 'SafetensorLoader', + 'StreamingSafetensorSaver', + 'LazyTensor', + # Helper functions + 'deep_getattr', + 'is_last_rank', + 'load_hf_weights_to_megatron', + # Legacy compatibility + 'create_megatron_args', + 'set_megatron_args', + 'restore_megatron_args', + 'mock_megatron_args', + # Initializer + 'MegatronModelInitializer', + 'initialize_megatron_model', + # Model metadata + 'Qwen3ModelMeta', + 'get_model_default_config', +] diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py new file mode 100644 index 00000000..33a56356 --- /dev/null +++ b/src/twinkle/megatron/model/bridge.py @@ -0,0 +1,1298 @@ +# Copyright (c) twinkle authors. All rights reserved. +# GPT Bridge for HuggingFace to Megatron-Core weight conversion. +# This implementation is adapted from ms-swift's GPTBridge. +"""Weight conversion bridge between HuggingFace and Megatron-Core formats. + +This module provides independent implementation for weight loading/saving, +adapted from swift's GPTBridge but without external dependencies. 
+ +Supports: +- Qwen2.5 / Qwen3 model families +- PEFT/LoRA format loading and saving +- Tensor Parallel / Pipeline Parallel weight sharding +- MoE (Mixture of Experts) models +""" +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from types import SimpleNamespace +from dataclasses import dataclass, field +from copy import copy +import os +import json +import glob +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from tqdm import tqdm + +try: + from megatron.core import parallel_state as mpu + MEGATRON_AVAILABLE = True +except ImportError: + MEGATRON_AVAILABLE = False + mpu = None + +try: + from safetensors import safe_open + from safetensors.torch import save_file + SAFETENSORS_AVAILABLE = True +except ImportError: + SAFETENSORS_AVAILABLE = False + + +def deep_getattr(obj, attr: str, default=None): + """Get nested attribute from object using dot notation.""" + try: + for key in attr.split('.'): + obj = getattr(obj, key) + return obj + except AttributeError: + return default + + +def is_last_rank() -> bool: + """Check if current process is the last rank.""" + if not dist.is_initialized(): + return True + return dist.get_rank() == dist.get_world_size() - 1 + + +class LazyTensor: + """Lazy tensor wrapper for deferred loading.""" + + def __init__(self, loader, key: str): + self._loader = loader + self._key = key + + def load(self) -> torch.Tensor: + """Load the tensor.""" + return self._loader.get_tensor(self._key) + + +class SafetensorLoader: + """Lazy loader for safetensor files.""" + + def __init__(self, model_dir: str, is_peft_format: bool = False): + self.model_dir = model_dir + self.is_peft_format = is_peft_format + self._handles = {} + self._index = None + self._key_to_file = {} + self._load_index() + + def _load_index(self): + """Load safetensor index file if exists.""" + # Try adapter format first for PEFT + if self.is_peft_format: + adapter_file = 
os.path.join(self.model_dir, 'adapter_model.safetensors') + if os.path.exists(adapter_file): + handle = safe_open(adapter_file, framework='pt', device='cpu') + for key in handle.keys(): + self._key_to_file[key] = adapter_file + self._handles[adapter_file] = handle + return + + # Standard index file + index_file = os.path.join(self.model_dir, 'model.safetensors.index.json') + if os.path.exists(index_file): + with open(index_file, 'r') as f: + self._index = json.load(f) + for key, filename in self._index['weight_map'].items(): + self._key_to_file[key] = os.path.join(self.model_dir, filename) + else: + # Single file model + single_file = os.path.join(self.model_dir, 'model.safetensors') + if os.path.exists(single_file): + handle = safe_open(single_file, framework='pt', device='cpu') + for key in handle.keys(): + self._key_to_file[key] = single_file + self._handles[single_file] = handle + else: + # Try to find any safetensor file + files = glob.glob(os.path.join(self.model_dir, '*.safetensors')) + for filepath in files: + handle = safe_open(filepath, framework='pt', device='cpu') + for key in handle.keys(): + self._key_to_file[key] = filepath + self._handles[filepath] = handle + + def _get_handle(self, filepath: str): + """Get or create file handle.""" + if filepath not in self._handles: + self._handles[filepath] = safe_open(filepath, framework='pt', device='cpu') + return self._handles[filepath] + + def get_tensor(self, key: str) -> torch.Tensor: + """Load a single tensor.""" + filepath = self._key_to_file.get(key) + if filepath is None: + raise KeyError(f"Tensor key not found: {key}") + handle = self._get_handle(filepath) + return handle.get_tensor(key) + + def get_lazy(self, key: str) -> LazyTensor: + """Get a lazy tensor reference.""" + if key not in self._key_to_file: + raise KeyError(f"Tensor key not found: {key}") + return LazyTensor(self, key) + + def get_state_dict(self) -> Dict[str, LazyTensor]: + """Get lazy state dict.""" + return {key: LazyTensor(self, 
key) for key in self._key_to_file} + + def keys(self) -> List[str]: + """Get all tensor keys.""" + return list(self._key_to_file.keys()) + + def __contains__(self, key: str) -> bool: + return key in self._key_to_file + + def close(self): + """Close all file handles.""" + self._handles.clear() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class StreamingSafetensorSaver: + """Streaming saver for safetensor files.""" + + def __init__(self, save_dir: str, max_shard_size: str = '5GB', is_peft_format: bool = False): + self.save_dir = save_dir + self.is_peft_format = is_peft_format + os.makedirs(save_dir, exist_ok=True) + + # Parse max shard size + size_str = max_shard_size.upper() + if size_str.endswith('GB'): + self.max_shard_bytes = int(float(size_str[:-2]) * 1024 ** 3) + elif size_str.endswith('MB'): + self.max_shard_bytes = int(float(size_str[:-2]) * 1024 ** 2) + else: + self.max_shard_bytes = int(size_str) + + self.current_shard = {} + self.current_shard_size = 0 + self.shard_idx = 1 + self.weight_map = {} + + def add_tensor(self, key: str, tensor: torch.Tensor): + """Add tensor to the current shard.""" + if tensor is None: + return + + tensor_size = tensor.numel() * tensor.element_size() + + # Flush if needed + if self.current_shard_size + tensor_size > self.max_shard_bytes and self.current_shard: + self._flush_shard() + + self.current_shard[key] = tensor.contiguous() + self.current_shard_size += tensor_size + + def _flush_shard(self): + """Flush current shard to disk.""" + if not self.current_shard: + return + + if self.is_peft_format: + filename = 'adapter_model.safetensors' + else: + filename = f'model-{self.shard_idx:05d}-of-XXXXX.safetensors' + + filepath = os.path.join(self.save_dir, filename) + save_file(self.current_shard, filepath) + + for key in self.current_shard: + self.weight_map[key] = filename + + self.current_shard = {} + self.current_shard_size = 0 + self.shard_idx += 1 + + def finalize(self): + 
"""Finalize and write index.""" + self._flush_shard() + + if self.is_peft_format: + return # PEFT format doesn't need index + + # Fix shard filenames + total_shards = self.shard_idx - 1 + if total_shards == 0: + return + + for old_name in list(self.weight_map.values()): + new_name = old_name.replace('XXXXX', f'{total_shards:05d}') + if old_name != new_name: + old_path = os.path.join(self.save_dir, old_name) + new_path = os.path.join(self.save_dir, new_name) + if os.path.exists(old_path): + os.rename(old_path, new_path) + for key in self.weight_map: + if self.weight_map[key] == old_name: + self.weight_map[key] = new_name + + if total_shards > 1: + index = { + 'metadata': {'total_size': sum(t.numel() * t.element_size() + for t in self.current_shard.values())}, + 'weight_map': self.weight_map + } + with open(os.path.join(self.save_dir, 'model.safetensors.index.json'), 'w') as f: + json.dump(index, f, indent=2) + + +@dataclass +class BridgeConfig: + """Configuration for GPTBridge.""" + # Parallelism + tp_size: int = 1 + pp_size: int = 1 + ep_size: int = 1 + etp_size: int = 1 + + # Model architecture + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_key_value_heads: int = 32 + num_layers: int = 32 + vocab_size: int = 32000 + padded_vocab_size: int = 32000 + intermediate_size: int = 11008 + + # Options + add_qkv_bias: bool = False + add_bias_linear: bool = False + qk_layernorm: bool = False + tie_word_embeddings: bool = False + + # MoE + num_experts: int = 0 + num_experts_per_tok: int = 2 + shared_expert_intermediate_size: int = 0 + + model_type: str = 'qwen2' + max_shard_size: str = '5GB' + + @classmethod + def from_hf_config( + cls, + hf_config: Any, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + padded_vocab_size: Optional[int] = None, + ) -> 'BridgeConfig': + """Create BridgeConfig from HuggingFace config.""" + vocab_size = getattr(hf_config, 'vocab_size', 32000) + if padded_vocab_size is None: + padded_vocab_size = vocab_size + # Pad to 
multiple of 64 for efficiency + if padded_vocab_size % 64 != 0: + padded_vocab_size = ((padded_vocab_size // 64) + 1) * 64 + + num_attention_heads = getattr(hf_config, 'num_attention_heads', 32) + num_key_value_heads = getattr(hf_config, 'num_key_value_heads', num_attention_heads) + + # MoE config + num_experts = getattr(hf_config, 'num_experts', 0) or \ + getattr(hf_config, 'n_routed_experts', 0) or \ + getattr(hf_config, 'num_local_experts', 0) + num_experts_per_tok = getattr(hf_config, 'num_experts_per_tok', 2) or \ + getattr(hf_config, 'moe_topk', 2) + shared_expert_size = getattr(hf_config, 'shared_expert_intermediate_size', 0) + + return cls( + tp_size=tp_size, + pp_size=pp_size, + ep_size=ep_size, + etp_size=tp_size, + hidden_size=getattr(hf_config, 'hidden_size', 4096), + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + num_layers=getattr(hf_config, 'num_hidden_layers', 32), + vocab_size=vocab_size, + padded_vocab_size=padded_vocab_size, + intermediate_size=getattr(hf_config, 'intermediate_size', 11008), + add_qkv_bias=getattr(hf_config, 'attention_bias', False), + add_bias_linear=getattr(hf_config, 'mlp_bias', False), + qk_layernorm=getattr(hf_config, 'qk_layernorm', False) or \ + getattr(hf_config, 'use_qk_norm', False), + tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False), + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + shared_expert_intermediate_size=shared_expert_size, + model_type=getattr(hf_config, 'model_type', 'qwen2'), + ) + + +class TwinkleGPTBridge: + """Bridge for converting weights between HuggingFace and Megatron-Core formats. + + Adapted from swift's GPTBridge implementation. + Supports Qwen2.5 / Qwen3 model families. 
+ """ + + # HuggingFace model structure constants (Qwen2/Qwen3 compatible) + HF_LAYERS_PREFIX = 'model.layers' + HF_EMBED_KEY = 'model.embed_tokens.weight' + HF_FINAL_LAYERNORM_KEY = 'model.norm.weight' + HF_LM_HEAD_KEY = 'lm_head.weight' + + def __init__(self, config: BridgeConfig, hf_config: Any = None, disable_tqdm: bool = False): + """Initialize the bridge. + + Args: + config: Bridge configuration. + hf_config: HuggingFace model config (for reference). + disable_tqdm: Whether to disable progress bar. + """ + self.config = config + self.hf_config = hf_config + self.disable_tqdm = disable_tqdm or not is_last_rank() + + # Parallel state + self.tp_size = config.tp_size + self.pp_size = config.pp_size + self.ep_size = config.ep_size + self.etp_size = config.etp_size + + # Get parallel ranks + if MEGATRON_AVAILABLE and mpu.is_initialized(): + self.tp_rank = mpu.get_tensor_model_parallel_rank() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + self.tp_group = mpu.get_tensor_model_parallel_group() + self.pp_group = mpu.get_pipeline_model_parallel_group() + try: + self.ep_rank = mpu.get_expert_model_parallel_rank() + self.ep_group = mpu.get_expert_model_parallel_group() + self.etp_rank = mpu.get_expert_tensor_parallel_rank() + self.etp_group = mpu.get_expert_tensor_parallel_group() + except: + self.ep_rank = 0 + self.ep_group = None + self.etp_rank = 0 + self.etp_group = None + else: + self.tp_rank = 0 + self.pp_rank = 0 + self.tp_group = None + self.pp_group = None + self.ep_rank = 0 + self.ep_group = None + self.etp_rank = 0 + self.etp_group = None + + # PEFT tracking + self._is_peft_format = False + self._adapter_name = 'default' + self._peft_target_modules: Set[str] = set() + self._peft_modules_to_save: Set[str] = set() + self._target_device = None + self._only_last_rank = False + + def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: + """Determine which dimension to split for tensor parallelism.""" + if mg_key is None: + return None + + # 
ColumnParallel (split output dim) + dim0_keys = { + 'word_embeddings', 'linear_qkv', 'output_layer', + 'linear_q_proj', 'linear_q_up_proj', 'linear_kv_up_proj', + 'eh_proj', # MTP + } + # RowParallel (split input dim) + dim1_keys = {'linear_proj', 'linear_fc2'} + + # Handle LoRA keys + if 'lora_A' not in mg_key and 'lora_B' not in mg_key: + key_parts = mg_key.rsplit('.', 2) + if len(key_parts) >= 2: + key = key_parts[-2] + suffix = key_parts[-1] + + if suffix == 'layer_norm_weight': + return None + elif key in dim0_keys: + return 0 + elif key in {'linear_fc1'} and suffix != 'bias': + return 1 + elif key in dim1_keys and suffix != 'bias': + return 1 + else: + # LoRA weights + key_parts = mg_key.rsplit('.', 3) + if len(key_parts) >= 2: + key = key_parts[0] + lora_name = key_parts[1] if len(key_parts) > 1 else '' + if lora_name == 'lora_A': + if key in dim1_keys: + return 1 + elif lora_name == 'lora_B': + if key in dim0_keys: + return 0 + elif key == 'linear_fc1': + return 1 + + return None + + def _split_tp(self, tensor: torch.Tensor, tp_dim: Optional[int], is_expert: bool = False) -> torch.Tensor: + """Split tensor for tensor parallelism.""" + tp_size = self.etp_size if is_expert else self.tp_size + tp_rank = self.etp_rank if is_expert else self.tp_rank + + if tp_dim is None or tp_size <= 1: + return tensor + return tensor.chunk(tp_size, dim=tp_dim)[tp_rank] + + def _all_gather_tp(self, tensor: Optional[torch.Tensor], tp_dim: Optional[int], + is_expert: bool = False) -> Optional[torch.Tensor]: + """All-gather tensor across tensor parallel group.""" + if tensor is None: + return None + + tensor = tensor.to('cuda') + tp_size = self.etp_size if is_expert else self.tp_size + tp_group = self.etp_group if is_expert else self.tp_group + + if tp_dim is None or tp_size <= 1: + return tensor + + if tp_dim == 0: + tensor_shape = list(tensor.shape) + tensor_shape[0] *= tp_size + output = tensor.new_empty(tensor_shape) + dist.all_gather_into_tensor(output, tensor, 
group=tp_group) + return output + else: + output = [torch.empty_like(tensor) for _ in range(tp_size)] + dist.all_gather(output, tensor, group=tp_group) + return torch.cat(output, dim=tp_dim) + + def _set_weight( + self, + mg_param: Union[torch.Tensor, nn.Parameter, List], + hf_weight: torch.Tensor, + mg_key: str, + is_expert: bool = False, + ): + """Set weight from HuggingFace to Megatron parameter.""" + tp_dim = self._get_tp_split_dim(mg_key) + tensor = self._split_tp(hf_weight, tp_dim, is_expert) + + if not isinstance(mg_param, (list, tuple)): + mg_param = [mg_param] + + tensor_list = tensor.chunk(len(mg_param), dim=0) + for i, param in enumerate(mg_param): + t = tensor_list[i].reshape(*param.shape) + param.data.copy_(t) + + def _get_weight( + self, + mg_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]], + mg_key: Optional[str], + is_expert: bool = False, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Get weight from Megatron parameter, gathered across TP.""" + if mg_weight is None: + return None, None + + tensor = mg_weight + if not isinstance(tensor, (list, tuple)): + tensor = [tensor] + + tensor = torch.cat(tensor, dim=0) + tp_dim = self._get_tp_split_dim(mg_key) + tensor = self._all_gather_tp(tensor, tp_dim, is_expert) + + if self._target_device is not None and tensor is not None: + tensor = tensor.to(device=self._target_device) + + if self._only_last_rank and not is_last_rank(): + return None, None + + return tensor, None + + # ========================================================================= + # Weight Loading Methods + # ========================================================================= + + def _load_embedding(self, mg_model, loader: SafetensorLoader): + """Load embedding weights.""" + embed_module = deep_getattr(mg_model, 'embedding.word_embeddings') + if embed_module is None: + return + + hf_weight = loader.get_tensor(self.HF_EMBED_KEY) + + # Pad vocabulary if needed + if hf_weight.shape[0] < 
self.config.padded_vocab_size: + hf_weight = F.pad( + hf_weight, + (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0]) + ) + + self._set_weight(embed_module.weight, hf_weight, 'word_embeddings.weight') + + def _load_output_layer(self, mg_model, loader: SafetensorLoader): + """Load output layer (lm_head) weights.""" + output_module = deep_getattr(mg_model, 'output_layer') + if output_module is None or output_module.weight is None: + return + + # Check if weights are tied + if self.config.tie_word_embeddings: + hf_weight = loader.get_tensor(self.HF_EMBED_KEY) + else: + hf_weight = loader.get_tensor(self.HF_LM_HEAD_KEY) + + # Pad vocabulary if needed + if hf_weight.shape[0] < self.config.padded_vocab_size: + hf_weight = F.pad( + hf_weight, + (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0]) + ) + + self._set_weight(output_module.weight, hf_weight, 'output_layer.weight') + + def _load_final_layernorm(self, mg_model, loader: SafetensorLoader): + """Load final layer norm weights.""" + ln_module = deep_getattr(mg_model, 'decoder.final_layernorm') + if ln_module is None: + return + + hf_weight = loader.get_tensor(self.HF_FINAL_LAYERNORM_KEY) + ln_module.weight.data.copy_(hf_weight) + + def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load attention layer weights.""" + mg_attn = mg_layer.self_attention + prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.self_attn.' 
+ + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + hidden_size = self.config.hidden_size + head_dim = hidden_size // num_heads + heads_per_group = num_heads // num_kv_heads + + # Load Q, K, V weights and merge into linear_qkv + q_weight = loader.get_tensor(f'{prefix}q_proj.weight') + k_weight = loader.get_tensor(f'{prefix}k_proj.weight') + v_weight = loader.get_tensor(f'{prefix}v_proj.weight') + + # Reshape for GQA + q_weight = q_weight.reshape(num_kv_heads, heads_per_group * head_dim, hidden_size) + k_weight = k_weight.reshape(num_kv_heads, head_dim, hidden_size) + v_weight = v_weight.reshape(num_kv_heads, head_dim, hidden_size) + + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=1) + qkv_weight = qkv_weight.reshape(-1, hidden_size) + + self._set_weight(mg_attn.linear_qkv.weight, qkv_weight, 'linear_qkv.weight') + + # Load O projection + o_weight = loader.get_tensor(f'{prefix}o_proj.weight') + self._set_weight(mg_attn.linear_proj.weight, o_weight, 'linear_proj.weight') + + # Load biases if present + if self.config.add_qkv_bias: + try: + q_bias = loader.get_tensor(f'{prefix}q_proj.bias') + k_bias = loader.get_tensor(f'{prefix}k_proj.bias') + v_bias = loader.get_tensor(f'{prefix}v_proj.bias') + + q_bias = q_bias.reshape(num_kv_heads, heads_per_group * head_dim) + k_bias = k_bias.reshape(num_kv_heads, head_dim) + v_bias = v_bias.reshape(num_kv_heads, head_dim) + + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).reshape(-1) + self._set_weight(mg_attn.linear_qkv.bias, qkv_bias, 'linear_qkv.bias') + except KeyError: + pass + + # Load input layernorm (may be fused) + ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.input_layernorm.weight' + ln_weight = loader.get_tensor(ln_key) + + ln_param = deep_getattr(mg_attn, 'linear_qkv.layer_norm_weight') + if ln_param is not None: + ln_param.data.copy_(ln_weight) + else: + ln_module = deep_getattr(mg_layer, 'input_layernorm') + if ln_module is not None: + 
ln_module.weight.data.copy_(ln_weight) + + # QK layernorm (Qwen3) + if self.config.qk_layernorm: + try: + q_norm = loader.get_tensor(f'{prefix}q_norm.weight') + k_norm = loader.get_tensor(f'{prefix}k_norm.weight') + q_ln = deep_getattr(mg_attn, 'q_layernorm') + k_ln = deep_getattr(mg_attn, 'k_layernorm') + if q_ln is not None: + q_ln.weight.data.copy_(q_norm) + if k_ln is not None: + k_ln.weight.data.copy_(k_norm) + except KeyError: + pass + + def _load_mlp(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load MLP layer weights.""" + mg_mlp = mg_layer.mlp + prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' + + # Check if gate_up_proj is fused + try: + gate_weight = loader.get_tensor(f'{prefix}gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}up_proj.weight') + + # Stack gate and up projections (shape: [2, intermediate, hidden]) + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) + self._set_weight(mg_mlp.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') + except KeyError: + # Try gate_up_proj (fused) + try: + gate_up_weight = loader.get_tensor(f'{prefix}gate_up_proj.weight') + gate_up_weight = gate_up_weight.view(2, -1, gate_up_weight.shape[-1]) + self._set_weight(mg_mlp.linear_fc1.weight, gate_up_weight, 'linear_fc1.weight') + except KeyError: + pass + + # Load down projection + try: + down_weight = loader.get_tensor(f'{prefix}down_proj.weight') + self._set_weight(mg_mlp.linear_fc2.weight, down_weight, 'linear_fc2.weight') + except KeyError: + pass + + # Load post attention layernorm + ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' + try: + ln_weight = loader.get_tensor(ln_key) + + ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln_param is not None: + ln_param.data.copy_(ln_weight) + else: + ln_module = deep_getattr(mg_layer, 'pre_mlp_layernorm') + if ln_module is not None: + ln_module.weight.data.copy_(ln_weight) + except KeyError: + pass + + def _load_moe(self, 
mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load MoE layer weights.""" + mg_mlp = mg_layer.mlp + prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' + + # Load router + try: + router_key = None + for key in ['gate.weight', 'router.weight', 'gate.wg.weight']: + full_key = f'{prefix}{key}' + if full_key in loader: + router_key = full_key + break + + if router_key: + router_weight = loader.get_tensor(router_key) + router_module = deep_getattr(mg_mlp, 'router') + if router_module is not None and hasattr(router_module, 'weight'): + router_module.weight.data.copy_(router_weight) + except KeyError: + pass + + # Load shared experts if present + if self.config.shared_expert_intermediate_size > 0: + for shared_key in ['shared_expert', 'shared_experts', 'shared_mlp']: + try: + gate_weight = loader.get_tensor(f'{prefix}{shared_key}.gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}{shared_key}.up_proj.weight') + down_weight = loader.get_tensor(f'{prefix}{shared_key}.down_proj.weight') + + shared_module = deep_getattr(mg_mlp, 'shared_experts') + if shared_module is not None: + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) + self._set_weight(shared_module.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') + self._set_weight(shared_module.linear_fc2.weight, down_weight, 'linear_fc2.weight') + break + except KeyError: + continue + + # Load experts + num_local_experts = self.config.num_experts // self.ep_size + experts_module = deep_getattr(mg_mlp, 'experts') + + if experts_module is not None: + for local_idx in range(num_local_experts): + global_idx = self.ep_rank * num_local_experts + local_idx + + try: + gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') + + # For grouped linear, weights are stored differently + if hasattr(experts_module, 
'linear_fc1'): + # TEGroupedLinear format + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) + # Set individual expert weight + fc1_param = getattr(experts_module.linear_fc1, f'weight{local_idx}', None) + if fc1_param is not None: + self._set_weight(fc1_param, fc1_weight, 'linear_fc1.weight', is_expert=True) + + fc2_param = getattr(experts_module.linear_fc2, f'weight{local_idx}', None) + if fc2_param is not None: + self._set_weight(fc2_param, down_weight, 'linear_fc2.weight', is_expert=True) + elif hasattr(experts_module, '__getitem__'): + # List of experts + expert = experts_module[local_idx] + if hasattr(expert, 'linear_fc1'): + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) + self._set_weight(expert.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') + self._set_weight(expert.linear_fc2.weight, down_weight, 'linear_fc2.weight') + except KeyError: + continue + + # Load post attention layernorm + ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' + try: + ln_weight = loader.get_tensor(ln_key) + ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln_param is not None: + ln_param.data.copy_(ln_weight) + except KeyError: + pass + + def _load_layer(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load a single transformer layer.""" + self._load_attention(mg_layer, loader, layer_idx) + + # Check if MoE layer + if self.config.num_experts > 0: + self._load_moe(mg_layer, loader, layer_idx) + else: + self._load_mlp(mg_layer, loader, layer_idx) + + def load_weights( + self, + mg_model: nn.Module, + model_path: str, + is_peft_format: bool = False, + adapter_name: str = 'default', + ) -> None: + """Load HuggingFace weights into Megatron model. + + Args: + mg_model: Megatron GPT model. + model_path: Path to HuggingFace checkpoint. + is_peft_format: Whether loading PEFT adapter weights. + adapter_name: Name of the adapter for PEFT. 
+ """ + self._is_peft_format = is_peft_format + self._adapter_name = adapter_name + + with torch.no_grad(): + with SafetensorLoader(model_path, is_peft_format=is_peft_format) as loader: + if is_peft_format: + self._load_peft_weights(mg_model, loader) + else: + self._load_base_weights(mg_model, loader) + + def _load_base_weights(self, mg_model: nn.Module, loader: SafetensorLoader): + """Load base model weights.""" + # Get decoder + decoder = deep_getattr(mg_model, 'decoder') + if decoder is None: + decoder = mg_model + + layers = getattr(decoder, 'layers', []) + + # Load pre-process (embedding) on first PP rank + if self.pp_size <= 1 or self.pp_rank == 0: + try: + self._load_embedding(mg_model, loader) + except Exception as e: + print(f"Warning: Failed to load embedding: {e}") + + # Load transformer layers + prog_bar = tqdm( + layers, + desc='Loading weights', + disable=self.disable_tqdm + ) + for mg_layer in prog_bar: + layer_idx = mg_layer.layer_number - 1 # 1-indexed to 0-indexed + try: + self._load_layer(mg_layer, loader, layer_idx) + except Exception as e: + print(f"Warning: Failed to load layer {layer_idx}: {e}") + + # Load post-process on last PP rank + if self.pp_size <= 1 or self.pp_rank == self.pp_size - 1: + try: + self._load_final_layernorm(mg_model, loader) + self._load_output_layer(mg_model, loader) + except Exception as e: + print(f"Warning: Failed to load post-process: {e}") + + def _load_peft_weights(self, mg_model: nn.Module, loader: SafetensorLoader): + """Load PEFT/LoRA adapter weights.""" + state_dict = loader.get_state_dict() + hf_prefix = 'base_model.model.' if self._is_peft_format else '' + + # Build mapping from HF keys to Megatron keys + for key, lazy_tensor in state_dict.items(): + # Remove base_model.model. prefix + if key.startswith(hf_prefix): + key = key[len(hf_prefix):] + + # Parse the key to find target module + if '.lora_A.' in key or '.lora_B.' 
in key: + tensor = lazy_tensor.load() + self._load_peft_tensor(mg_model, key, tensor) + + def _load_peft_tensor(self, mg_model: nn.Module, key: str, tensor: torch.Tensor): + """Load a single PEFT tensor into the model.""" + # Parse key: model.layers.0.self_attn.q_proj.lora_A.weight + parts = key.split('.') + + # Find layer index + layer_idx = None + for i, p in enumerate(parts): + if p == 'layers' and i + 1 < len(parts): + layer_idx = int(parts[i + 1]) + break + + if layer_idx is None: + return + + # Get layer + decoder = deep_getattr(mg_model, 'decoder') + if decoder is None: + decoder = mg_model + + layers = getattr(decoder, 'layers', []) + for layer in layers: + if layer.layer_number - 1 == layer_idx: + mg_layer = layer + break + else: + return + + # Determine target and lora type + is_lora_A = '.lora_A.' in key + is_lora_B = '.lora_B.' in key + + if 'self_attn' in key: + mg_attn = mg_layer.self_attention + if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key: + target = deep_getattr(mg_attn, 'linear_qkv') + elif 'o_proj' in key: + target = deep_getattr(mg_attn, 'linear_proj') + else: + return + elif 'mlp' in key: + mg_mlp = mg_layer.mlp + if 'gate_proj' in key or 'up_proj' in key: + target = deep_getattr(mg_mlp, 'linear_fc1') + elif 'down_proj' in key: + target = deep_getattr(mg_mlp, 'linear_fc2') + else: + return + else: + return + + if target is None: + return + + # Get LoRA module + if is_lora_A: + lora_module = deep_getattr(target, f'lora_A.{self._adapter_name}') + else: + lora_module = deep_getattr(target, f'lora_B.{self._adapter_name}') + + if lora_module is not None and hasattr(lora_module, 'weight'): + lora_module.weight.data.copy_(tensor) + + # ========================================================================= + # Weight Saving Methods + # ========================================================================= + + def export_weights( + self, + mg_models: Union[nn.Module, List[nn.Module]], + target_device: Optional[str] = None, + 
only_last_rank: bool = False, + is_peft_format: bool = False, + tqdm_desc: str = 'Exporting: ', + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Export weights from Megatron model to HuggingFace format. + + Yields: + Tuples of (key, tensor) for each weight. + """ + self._target_device = target_device + self._only_last_rank = only_last_rank + self._is_peft_format = is_peft_format + self._adapter_name = 'default' + self._peft_target_modules = set() + self._peft_modules_to_save = set() + + if not isinstance(mg_models, (list, tuple)): + mg_models = [mg_models] + + hf_prefix = 'base_model.model.' if is_peft_format else '' + + with torch.no_grad(): + # For now, handle single model + mg_model = mg_models[0] + + decoder = deep_getattr(mg_model, 'decoder') + if decoder is None: + decoder = mg_model + + layers = getattr(decoder, 'layers', []) + + if not is_peft_format: + # Export embedding + if self.pp_size <= 1 or self.pp_rank == 0: + embed = deep_getattr(mg_model, 'embedding.word_embeddings.weight') + if embed is not None: + weight, _ = self._get_weight(embed.data, 'word_embeddings.weight') + if weight is not None: + weight = weight[:self.config.vocab_size] + yield f'{hf_prefix}{self.HF_EMBED_KEY}', weight + + # Export layers + prog_bar = tqdm(layers, desc=tqdm_desc, disable=self.disable_tqdm) + for mg_layer in prog_bar: + layer_idx = mg_layer.layer_number - 1 + yield from self._export_layer(mg_layer, layer_idx, hf_prefix, is_peft_format) + + if not is_peft_format: + # Export final layernorm and output layer + if self.pp_size <= 1 or self.pp_rank == self.pp_size - 1: + ln_module = deep_getattr(mg_model, 'decoder.final_layernorm') + if ln_module is not None: + yield f'{hf_prefix}{self.HF_FINAL_LAYERNORM_KEY}', ln_module.weight.data.clone() + + output = deep_getattr(mg_model, 'output_layer.weight') + if output is not None: + weight, _ = self._get_weight(output.data, 'output_layer.weight') + if weight is not None: + weight = weight[:self.config.vocab_size] + yield 
f'{hf_prefix}{self.HF_LM_HEAD_KEY}', weight + + def _export_layer( + self, + mg_layer, + layer_idx: int, + hf_prefix: str, + is_peft_format: bool, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Export a single layer.""" + prefix = f'{hf_prefix}{self.HF_LAYERS_PREFIX}.{layer_idx}.' + + mg_attn = mg_layer.self_attention + mg_mlp = mg_layer.mlp + + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + hidden_size = self.config.hidden_size + head_dim = hidden_size // num_heads + heads_per_group = num_heads // num_kv_heads + q_dim = heads_per_group * head_dim + kv_dim = head_dim + + if not is_peft_format: + # Export QKV + qkv_weight, _ = self._get_weight(mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') + if qkv_weight is not None: + qkv_weight = qkv_weight.reshape(num_kv_heads, -1, hidden_size) + yield f'{prefix}self_attn.q_proj.weight', qkv_weight[:, :q_dim, :].reshape(-1, hidden_size).clone() + yield f'{prefix}self_attn.k_proj.weight', qkv_weight[:, q_dim:q_dim+kv_dim, :].reshape(-1, hidden_size).clone() + yield f'{prefix}self_attn.v_proj.weight', qkv_weight[:, -kv_dim:, :].reshape(-1, hidden_size).clone() + + # Export O + o_weight, _ = self._get_weight(mg_attn.linear_proj.weight.data, 'linear_proj.weight') + if o_weight is not None: + yield f'{prefix}self_attn.o_proj.weight', o_weight + + # Export layernorms + ln = deep_getattr(mg_attn, 'linear_qkv.layer_norm_weight') + if ln is not None: + yield f'{prefix}input_layernorm.weight', ln.data.clone() + + # Export MLP + fc1_weight, _ = self._get_weight(mg_mlp.linear_fc1.weight.data, 'linear_fc1.weight') + if fc1_weight is not None: + fc1_weight = fc1_weight.view(2, -1, hidden_size) + yield f'{prefix}mlp.gate_proj.weight', fc1_weight[0].clone() + yield f'{prefix}mlp.up_proj.weight', fc1_weight[1].clone() + + fc2_weight, _ = self._get_weight(mg_mlp.linear_fc2.weight.data, 'linear_fc2.weight') + if fc2_weight is not None: + yield f'{prefix}mlp.down_proj.weight', 
fc2_weight + + ln2 = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln2 is not None: + yield f'{prefix}post_attention_layernorm.weight', ln2.data.clone() + else: + # Export LoRA weights only + yield from self._export_lora_layer(mg_attn, mg_mlp, prefix) + + def _export_lora_layer( + self, + mg_attn, + mg_mlp, + prefix: str, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Export LoRA weights from a layer.""" + # Check if LoRA is applied + from twinkle.megatron.tuners import LoraParallelLinear + + # Attention LoRA + if isinstance(mg_attn.linear_qkv, LoraParallelLinear): + lora_A = deep_getattr(mg_attn.linear_qkv, f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_attn.linear_qkv, f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, 'linear_qkv.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, 'linear_qkv.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) + # Split lora_B for Q, K, V + for key in ['q_proj', 'k_proj', 'v_proj']: + yield f'{prefix}self_attn.{key}.lora_A.weight', lora_A.clone() + + num_kv_heads = self.config.num_key_value_heads + head_dim = self.config.hidden_size // self.config.num_attention_heads + heads_per_group = self.config.num_attention_heads // num_kv_heads + q_dim = heads_per_group * head_dim + + lora_B = lora_B.reshape(num_kv_heads, -1, lora_B.shape[-1]) + yield f'{prefix}self_attn.q_proj.lora_B.weight', lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() + yield f'{prefix}self_attn.k_proj.lora_B.weight', lora_B[:, q_dim:-head_dim, :].reshape(-1, lora_B.shape[-1]).clone() + yield f'{prefix}self_attn.v_proj.lora_B.weight', lora_B[:, -head_dim:, :].reshape(-1, lora_B.shape[-1]).clone() + + # O projection LoRA + if isinstance(mg_attn.linear_proj, LoraParallelLinear): + lora_A = deep_getattr(mg_attn.linear_proj, f'lora_A.{self._adapter_name}.weight') + 
lora_B = deep_getattr(mg_attn.linear_proj, f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, 'linear_proj.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, 'linear_proj.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.add('o_proj') + yield f'{prefix}self_attn.o_proj.lora_A.weight', lora_A.clone() + yield f'{prefix}self_attn.o_proj.lora_B.weight', lora_B.clone() + + # MLP LoRA + if hasattr(mg_mlp, 'linear_fc1') and isinstance(mg_mlp.linear_fc1, LoraParallelLinear): + lora_A = deep_getattr(mg_mlp.linear_fc1, f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_mlp.linear_fc1, f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, 'linear_fc1.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, 'linear_fc1.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.update({'gate_proj', 'up_proj'}) + for key in ['gate_proj', 'up_proj']: + yield f'{prefix}mlp.{key}.lora_A.weight', lora_A.clone() + + lora_B = lora_B.reshape(2, -1, lora_B.shape[-1]) + yield f'{prefix}mlp.gate_proj.lora_B.weight', lora_B[0].clone() + yield f'{prefix}mlp.up_proj.lora_B.weight', lora_B[1].clone() + + if hasattr(mg_mlp, 'linear_fc2') and isinstance(mg_mlp.linear_fc2, LoraParallelLinear): + lora_A = deep_getattr(mg_mlp.linear_fc2, f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_mlp.linear_fc2, f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, 'linear_fc2.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, 'linear_fc2.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.add('down_proj') + yield f'{prefix}mlp.down_proj.lora_A.weight', lora_A.clone() + yield f'{prefix}mlp.down_proj.lora_B.weight', lora_B.clone() + + def save_weights( + self, + 
mg_models: Union[nn.Module, List[nn.Module]], + output_dir: str, + is_peft_format: bool = False, + ) -> None: + """Save Megatron model weights in HuggingFace format. + + Args: + mg_models: Megatron model(s) to save. + output_dir: Directory to save weights. + is_peft_format: Whether saving in PEFT format. + """ + torch.cuda.empty_cache() + + saver = StreamingSafetensorSaver( + save_dir=output_dir, + max_shard_size=self.config.max_shard_size, + is_peft_format=is_peft_format, + ) + + for key, tensor in self.export_weights( + mg_models, + target_device='cpu', + only_last_rank=True, + is_peft_format=is_peft_format, + tqdm_desc='Saving: ', + ): + saver.add_tensor(key, tensor) + + saver.finalize() + + # Save config on last rank + if is_last_rank(): + if is_peft_format and not isinstance(mg_models, (list, tuple)): + mg_models = [mg_models] + + if is_peft_format and hasattr(mg_models[0], 'peft_config'): + peft_config = copy(mg_models[0].peft_config.get(self._adapter_name)) + if peft_config is not None: + peft_config.target_modules = list(self._peft_target_modules) + peft_config.modules_to_save = list(self._peft_modules_to_save) + peft_config.save_pretrained(output_dir) + elif not is_peft_format and self.hf_config is not None: + # Save HF config + self.hf_config.vocab_size = self.config.padded_vocab_size + self.hf_config.save_pretrained(output_dir) + + if dist.is_initialized(): + dist.barrier() + + +class TwinkleBridgeAdapter: + """Adapter for weight loading using TwinkleGPTBridge. + + Provides a simple interface for loading HF weights into Megatron models. 
+ """ + + def __init__( + self, + hf_config: Any, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + etp_size: Optional[int] = None, + model_path: Optional[str] = None, + padded_vocab_size: Optional[int] = None, + **kwargs, + ): + """Initialize the bridge adapter.""" + self.hf_config = hf_config + self.model_path = model_path + + # Create bridge config + self.config = BridgeConfig.from_hf_config( + hf_config=hf_config, + tp_size=tp_size, + pp_size=pp_size, + ep_size=ep_size, + padded_vocab_size=padded_vocab_size, + ) + if etp_size is not None: + self.config.etp_size = etp_size + + self._bridge = None + + def _get_bridge(self) -> TwinkleGPTBridge: + """Get or create the bridge instance.""" + if self._bridge is None: + self._bridge = TwinkleGPTBridge( + config=self.config, + hf_config=self.hf_config, + ) + return self._bridge + + def load_weights( + self, + mg_model: nn.Module, + model_path: Optional[str] = None, + is_peft_format: bool = False, + adapter_name: str = 'default', + ) -> None: + """Load HuggingFace weights into Megatron model.""" + model_path = model_path or self.model_path + if model_path is None: + raise ValueError("model_path must be provided") + + bridge = self._get_bridge() + bridge.load_weights(mg_model, model_path, is_peft_format, adapter_name) + + def save_weights( + self, + mg_models: Union[nn.Module, List[nn.Module]], + output_dir: str, + is_peft_format: bool = False, + ) -> None: + """Save Megatron model weights in HuggingFace format.""" + bridge = self._get_bridge() + bridge.save_weights(mg_models, output_dir, is_peft_format) + + +# Legacy functions for backward compatibility +def create_megatron_args(*args, **kwargs) -> SimpleNamespace: + """Legacy function - use BridgeConfig instead.""" + return SimpleNamespace(**kwargs) + + +def set_megatron_args(args: SimpleNamespace) -> None: + """Legacy function - no longer needed with TwinkleGPTBridge.""" + pass + + +def restore_megatron_args() -> None: + """Legacy function - no longer 
needed with TwinkleGPTBridge.""" + pass + + +def mock_megatron_args(args: SimpleNamespace): + """Legacy function - no longer needed with TwinkleGPTBridge.""" + from contextlib import contextmanager + @contextmanager + def noop(): + yield args + return noop() + + +def load_hf_weights_to_megatron( + mg_model: nn.Module, + model_path: str, + hf_config: Any, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + padded_vocab_size: Optional[int] = None, +) -> None: + """Convenience function to load HF weights into Megatron model.""" + adapter = TwinkleBridgeAdapter( + hf_config=hf_config, + tp_size=tp_size, + pp_size=pp_size, + ep_size=ep_size, + model_path=model_path, + padded_vocab_size=padded_vocab_size, + ) + adapter.load_weights(mg_model, model_path) diff --git a/src/twinkle/megatron/model/initializer.py b/src/twinkle/megatron/model/initializer.py new file mode 100644 index 00000000..ae0b5346 --- /dev/null +++ b/src/twinkle/megatron/model/initializer.py @@ -0,0 +1,325 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron model initialization from HuggingFace checkpoints.""" +from dataclasses import fields +from typing import Any, Dict, Optional, Type + +import torch +import torch.nn as nn +import torch.distributed as dist + +# Direct imports - assume megatron is installed +import megatron.core +from megatron.core import parallel_state as mpu +from megatron.core.transformer import TransformerConfig +from megatron.core.models.gpt import GPTModel +from packaging import version + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + +from ..utils import convert_hf_config + + +def _get_transformer_config_fields() -> set: + """Get valid field names for TransformerConfig. + + Returns: + Set of valid field names. + """ + return {f.name for f in fields(TransformerConfig)} + + +class MegatronModelInitializer: + """Initialize Megatron-Core models from HuggingFace checkpoints. 
class MegatronModelInitializer:
    """Initialize Megatron-Core models from HuggingFace checkpoints.

    This class handles:
    - Converting HuggingFace config to Megatron TransformerConfig
    - Creating Megatron model architecture
    - Loading HuggingFace weights into Megatron model
    """

    def __init__(
        self,
        tp_size: int = 1,
        pp_size: int = 1,
        cp_size: int = 1,
        ep_size: int = 1,
        etp_size: Optional[int] = None,
        vp_size: Optional[int] = None,
        sequence_parallel: bool = False,
        params_dtype: torch.dtype = torch.bfloat16,
        use_cpu_initialization: bool = True,
    ):
        """Initialize MegatronModelInitializer.

        Args:
            tp_size: Tensor parallel size.
            pp_size: Pipeline parallel size.
            cp_size: Context parallel size.
            ep_size: Expert parallel size.
            etp_size: Expert tensor parallel size (defaults to tp_size).
            vp_size: Virtual pipeline parallel size.
            sequence_parallel: Enable sequence parallelism.
            params_dtype: Parameter data type.
            use_cpu_initialization: Initialize model on CPU first.
        """
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.cp_size = cp_size
        self.ep_size = ep_size
        self.etp_size = etp_size or tp_size
        self.vp_size = vp_size
        self.sequence_parallel = sequence_parallel
        self.params_dtype = params_dtype
        self.use_cpu_initialization = use_cpu_initialization

        # Cache valid TransformerConfig fields so version-specific optional
        # fields can be filtered out before constructing the config.
        self._valid_config_fields = _get_transformer_config_fields()

    def create_transformer_config(
        self,
        hf_config: Any,
        **overrides,
    ) -> 'TransformerConfig':
        """Create Megatron TransformerConfig from HuggingFace config.

        NOTE: as a side effect this also stores rotary-embedding settings on
        ``self`` (``_rotary_base``, ``_rotary_percent``,
        ``_position_embedding_type``), which ``create_gpt_model`` reads
        afterwards — call order between the two methods matters.

        Args:
            hf_config: HuggingFace model config.
            **overrides: Config overrides.

        Returns:
            Megatron TransformerConfig.
        """
        # Convert HuggingFace config to dict
        mg_config_dict = convert_hf_config(hf_config)

        # Apply overrides
        mg_config_dict.update(overrides)

        # Build config kwargs with only valid fields
        config_kwargs = {
            # Required fields
            'num_layers': mg_config_dict['num_layers'],
            'hidden_size': mg_config_dict['hidden_size'],
            'num_attention_heads': mg_config_dict['num_attention_heads'],
            # Parallel settings
            'tensor_model_parallel_size': self.tp_size,
            'pipeline_model_parallel_size': self.pp_size,
            'context_parallel_size': self.cp_size,
            'expert_model_parallel_size': self.ep_size,
            'sequence_parallel': self.sequence_parallel,
            'params_dtype': self.params_dtype,
            'use_cpu_initialization': self.use_cpu_initialization,
        }

        # Optional fields - only add if valid for this Megatron version
        optional_fields = {
            'num_query_groups': mg_config_dict.get('num_query_groups', mg_config_dict['num_attention_heads']),
            'ffn_hidden_size': mg_config_dict.get('ffn_hidden_size', 4 * mg_config_dict['hidden_size']),
            'num_moe_experts': mg_config_dict.get('num_experts'),
            'moe_router_topk': mg_config_dict.get('moe_router_topk', 2) if mg_config_dict.get('num_experts') else None,
            'layernorm_epsilon': mg_config_dict.get('norm_epsilon', 1e-6),
            'add_qkv_bias': mg_config_dict.get('add_qkv_bias', False),
            'add_bias_linear': not mg_config_dict.get('disable_bias_linear', True),
            'gated_linear_unit': mg_config_dict.get('swiglu', True),
            'qk_layernorm': mg_config_dict.get('qk_layernorm', False),
            'normalization': 'RMSNorm',
        }

        # Add optional fields that are valid for this Megatron version;
        # None values are dropped (e.g. moe_router_topk for dense models).
        for key, value in optional_fields.items():
            if key in self._valid_config_fields and value is not None:
                config_kwargs[key] = value

        # Store rotary settings for GPTModel (not TransformerConfig)
        self._rotary_base = mg_config_dict.get('rotary_base', 10000)
        self._rotary_percent = mg_config_dict.get('rotary_percent', 1.0)
        self._position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope')

        # Create TransformerConfig
        config = TransformerConfig(**config_kwargs)

        return config

    def create_gpt_model(
        self,
        hf_config: Any,
        vocab_size: Optional[int] = None,
        max_sequence_length: Optional[int] = None,
        **config_overrides,
    ) -> 'GPTModel':
        """Create Megatron GPT model from HuggingFace config.

        Args:
            hf_config: HuggingFace model config.
            vocab_size: Override vocab size.
            max_sequence_length: Override max sequence length.
            **config_overrides: Config overrides.

        Returns:
            Megatron GPTModel.
        """
        # Create config (also sets self._rotary_base, etc.)
        config = self.create_transformer_config(hf_config, **config_overrides)

        # Get vocab size
        if vocab_size is None:
            vocab_size = hf_config.vocab_size

        # Pad vocab size for tensor parallelism
        padded_vocab_size = self._pad_vocab_size(vocab_size)

        # Get max sequence length
        if max_sequence_length is None:
            max_sequence_length = getattr(hf_config, 'max_position_embeddings', 4096)

        # Get tie_word_embeddings setting
        tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)

        # Create model with rotary settings passed directly to GPTModel.
        # pre/post_process depend on this rank's pipeline stage, so the model
        # must be created after parallel state is initialized.
        model = GPTModel(
            config=config,
            transformer_layer_spec=self._get_layer_spec(config),
            vocab_size=padded_vocab_size,
            max_sequence_length=max_sequence_length,
            pre_process=mpu.is_pipeline_first_stage(),
            post_process=mpu.is_pipeline_last_stage(),
            parallel_output=True,
            share_embeddings_and_output_weights=tie_word_embeddings,
            position_embedding_type=self._position_embedding_type,
            rotary_percent=self._rotary_percent,
            rotary_base=self._rotary_base,
        )

        return model

    def _pad_vocab_size(self, vocab_size: int) -> int:
        """Pad vocab size for tensor parallelism.

        Args:
            vocab_size: Original vocab size.

        Returns:
            Padded vocab size.
        """
        # Pad to multiple of tp_size * 128 for efficient parallelism
        divisor = self.tp_size * 128
        return ((vocab_size + divisor - 1) // divisor) * divisor

    def _get_layer_spec(self, config: 'TransformerConfig'):
        """Get transformer layer specification.

        Args:
            config: Transformer config.

        Returns:
            Layer specification (ModuleSpec or TransformerBlockSubmodules).
        """
        from megatron.core.models.gpt.gpt_layer_specs import (
            get_gpt_layer_with_transformer_engine_spec,
            get_gpt_layer_local_spec,
        )

        # Determine if this is a MoE model
        num_experts = getattr(config, 'num_moe_experts', None)
        moe_grouped_gemm = getattr(config, 'moe_grouped_gemm', False)
        qk_layernorm = getattr(config, 'qk_layernorm', False)
        multi_latent_attention = getattr(config, 'multi_latent_attention', False)

        # Try TE (TransformerEngine) layers first for better performance
        try:
            return get_gpt_layer_with_transformer_engine_spec(
                num_experts=num_experts,
                moe_grouped_gemm=moe_grouped_gemm,
                qk_layernorm=qk_layernorm,
                multi_latent_attention=multi_latent_attention,
            )
        except Exception:
            # Fallback to local spec without TE (e.g. TransformerEngine not
            # installed or incompatible with this Megatron version).
            return get_gpt_layer_local_spec(
                num_experts=num_experts,
                moe_grouped_gemm=moe_grouped_gemm,
                qk_layernorm=qk_layernorm,
                multi_latent_attention=multi_latent_attention,
            )

    def load_from_hf(
        self,
        model: nn.Module,
        hf_model_path: str,
        hf_config: Any,
    ) -> None:
        """Load HuggingFace checkpoint into Megatron model.

        Uses swift's GPTBridge for maximum compatibility and stability.

        Args:
            model: The Megatron model.
            hf_model_path: Path to HuggingFace checkpoint.
            hf_config: HuggingFace model config.
        """
        from .bridge import TwinkleBridgeAdapter

        # Calculate padded vocab size — must match the size used when the
        # model was created in create_gpt_model.
        padded_vocab_size = self._pad_vocab_size(hf_config.vocab_size)

        # Create bridge adapter
        adapter = TwinkleBridgeAdapter(
            hf_config=hf_config,
            tp_size=self.tp_size,
            pp_size=self.pp_size,
            ep_size=self.ep_size,
            etp_size=self.etp_size,
            model_path=hf_model_path,
            padded_vocab_size=padded_vocab_size,
        )

        # Load weights using swift's bridge
        adapter.load_weights(model, hf_model_path)


def initialize_megatron_model(
    hf_model_path: str,
    tp_size: int = 1,
    pp_size: int = 1,
    cp_size: int = 1,
    ep_size: int = 1,
    params_dtype: torch.dtype = torch.bfloat16,
    load_weights: bool = True,
) -> nn.Module:
    """Convenience function to initialize Megatron model from HuggingFace checkpoint.

    Args:
        hf_model_path: Path to HuggingFace checkpoint.
        tp_size: Tensor parallel size.
        pp_size: Pipeline parallel size.
        cp_size: Context parallel size.
        ep_size: Expert parallel size.
        params_dtype: Parameter data type.
        load_weights: Whether to load weights.

    Returns:
        Initialized Megatron model.
    """
    from transformers import AutoConfig

    # Load HuggingFace config
    hf_config = AutoConfig.from_pretrained(hf_model_path)

    # Create initializer
    initializer = MegatronModelInitializer(
        tp_size=tp_size,
        pp_size=pp_size,
        cp_size=cp_size,
        ep_size=ep_size,
        params_dtype=params_dtype,
    )

    # Create model
    model = initializer.create_gpt_model(hf_config)

    # Load weights
    if load_weights:
        initializer.load_from_hf(model, hf_model_path, hf_config)

    return model
# =============================================================================
# Qwen3 Model Metadata
# =============================================================================
class Qwen3ModelMeta:
    """Static metadata describing the Qwen3 model family."""

    # Supported architectures. NOTE(review): HF ships Qwen2.5 checkpoints under
    # the 'Qwen2ForCausalLM' architecture; confirm 'Qwen2.5ForCausalLM' is ever
    # actually emitted before relying on it.
    DENSE_ARCHITECTURES = ['Qwen3ForCausalLM', 'Qwen2ForCausalLM', 'Qwen2.5ForCausalLM']
    MOE_ARCHITECTURES = ['Qwen3MoeForCausalLM', 'Qwen2MoeForCausalLM']
    ALL_ARCHITECTURES = DENSE_ARCHITECTURES + MOE_ARCHITECTURES

    # HuggingFace key prefixes used when mapping weights.
    HF_LAYERS_PREFIX = 'model.layers'
    HF_EMBED_KEY = 'model.embed_tokens.weight'
    HF_FINAL_LAYERNORM_KEY = 'model.norm.weight'
    HF_LM_HEAD_KEY = 'lm_head.weight'

    # Megatron config overrides shared by all Qwen3 variants.
    DEFAULT_CONFIG = {
        'qk_layernorm': True,
        'swiglu': True,
        'disable_bias_linear': True,
        'rotary_interleaved': False,
    }

    # Additional overrides that only apply to the MoE variants.
    MOE_CONFIG = {
        'use_shared_expert_gate': True,
    }

    @classmethod
    def is_qwen3(cls, architecture: str) -> bool:
        """Return True when *architecture* belongs to the Qwen3 family."""
        return architecture in cls.ALL_ARCHITECTURES

    @classmethod
    def is_qwen3_moe(cls, architecture: str) -> bool:
        """Return True when *architecture* is a Qwen3 MoE variant."""
        return architecture in cls.MOE_ARCHITECTURES


def get_model_default_config(architecture: str) -> Dict[str, Any]:
    """Return Megatron TransformerConfig overrides for *architecture*.

    Args:
        architecture: Model architecture name.

    Returns:
        Default config dict for Megatron TransformerConfig; empty when the
        architecture is not a recognized Qwen model.
    """
    if Qwen3ModelMeta.is_qwen3_moe(architecture):
        merged = dict(Qwen3ModelMeta.DEFAULT_CONFIG)
        merged.update(Qwen3ModelMeta.MOE_CONFIG)
        return merged
    if Qwen3ModelMeta.is_qwen3(architecture):
        return Qwen3ModelMeta.DEFAULT_CONFIG
    return {}
+"""Megatron-compatible LoRA implementation with Tensor Parallel support.""" +import math +import warnings +from contextlib import contextmanager +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Direct imports - assume megatron and peft are installed +import megatron.core +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, TEColumnParallelLinear, + TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear, + TERowParallelGroupedLinear, TERowParallelLinear +) +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.parallel_state import ( + get_expert_tensor_parallel_world_size, + get_tensor_model_parallel_world_size +) +from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + scatter_to_sequence_parallel_region +) +from megatron.core.transformer.mlp import apply_swiglu_sharded_factory +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.moe.router import TopKRouter +from packaging import version + +from peft.tuners.lora import model +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + + +class LoraParallelLinear(MegatronModule, LoraLayer): + """LoRA layer compatible with Megatron Tensor Parallel Linear layers. + + This class wraps Megatron's parallel linear layers (TELinear, TEColumnParallelLinear, + TERowParallelLinear, etc.) and adds LoRA adapters that are correctly sharded + across tensor parallel ranks. 
+ """ + + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + lora_bias: bool = False, + **kwargs, + ): + """Initialize LoraParallelLinear. + + Args: + base_layer: The Megatron parallel linear layer to wrap. + adapter_name: Name of the LoRA adapter. + r: LoRA rank. + lora_alpha: LoRA alpha scaling factor. + lora_dropout: Dropout probability for LoRA layers. + fan_in_fan_out: Whether the layer uses fan-in/fan-out convention. + init_lora_weights: Whether to initialize LoRA weights. + use_rslora: Use rank-stabilized LoRA scaling. + use_dora: Use DoRA (not supported yet). + lora_bias: Whether to add bias to LoRA layers. + """ + config = base_layer.config + super().__init__(config=config) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + LoraLayer.__init__(self, base_layer=base_layer) + + if use_dora: + raise ValueError(f'{self.__class__.__name__} does not support DoRA yet, please set it to False') + + self.is_parallel_a = isinstance(base_layer, (TERowParallelLinear, TERowParallelGroupedLinear)) + self.is_grouped = isinstance(base_layer, TEGroupedLinear) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.is_expert = getattr(base_layer, 'is_expert', False) + self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False) + + if self.is_expert: + self.tp_size = get_expert_tensor_parallel_world_size() + else: + self.tp_size = get_tensor_model_parallel_world_size() + + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + lora_bias=lora_bias, + ) + + self.is_target_conv_1d_layer = False + + def update_layer( + self, + adapter_name: str, + r: int, + *, + lora_alpha: int, + lora_dropout: float, + init_lora_weights: bool, 
+ use_rslora: bool, + lora_bias: bool, + **kwargs + ): + """Update LoRA layer with new adapter configuration. + + Args: + adapter_name: Name of the adapter. + r: LoRA rank. + lora_alpha: LoRA alpha scaling factor. + lora_dropout: Dropout probability. + init_lora_weights: Whether to initialize weights. + use_rslora: Use rank-stabilized LoRA. + lora_bias: Whether to add bias. + """ + if r <= 0: + raise ValueError(f'`r` should be a positive integer value but the value passed is {r}') + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + + # Build LoRA A and B matrices with proper parallelism + kwargs = { + 'skip_bias_add': False, + 'init_method': self.config.init_method, + 'config': self.config, + 'is_expert': self.is_expert, + } + if mcore_013: + kwargs['tp_group'] = self.base_layer.tp_group + + if isinstance(self.base_layer, TopKRouter): + # Router layer - no parallelism needed + router_shape = self.base_layer.weight.shape + lora_a = TELinear( + input_size=router_shape[1], + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_b = TELinear( + input_size=r, + output_size=router_shape[0], + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + elif self.is_parallel_a: + # Row parallel layer - LoRA A is parallel, LoRA B is not + in_features = self.in_features * self.tp_size + if self.is_grouped: + lora_a = TERowParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=in_features, + output_size=r, + bias=False, + **kwargs, + ) + lora_b = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + **kwargs, + ) + else: + lora_a = TERowParallelLinear( + 
input_size=in_features, + output_size=r, + bias=False, + input_is_parallel=True, + **kwargs, + ) + lora_b = TELinear( + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_a.parallel_mode = self.base_layer.parallel_mode + else: + # Column parallel layer - LoRA A is not parallel, LoRA B is parallel + out_features = self.out_features * self.tp_size + if self.is_grouped: + lora_a = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + **kwargs + ) + lora_b = TEColumnParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=out_features, + bias=lora_bias, + **kwargs, + ) + else: + lora_a = TELinear( + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs + ) + lora_b = TEColumnParallelLinear( + input_size=r, + output_size=out_features, + bias=lora_bias, + gather_output=False, + **kwargs, + ) + lora_b.parallel_mode = self.base_layer.parallel_mode + + # Disable overlap for LoRA layers + for lora in [lora_a, lora_b]: + if isinstance(lora, (TERowParallelLinear, TEColumnParallelLinear)) and lora.parallel_mode is None: + lora.ub_overlap_rs_fprop = False + lora.ub_overlap_ag_dgrad = False + lora.ub_overlap_ag_fprop = False + lora.ub_overlap_rs_dgrad = False + + lora_a.sequence_parallel = False + lora_b.sequence_parallel = False + + self.lora_A[adapter_name] = lora_a + self.lora_B[adapter_name] = lora_b + + if hasattr(self, 'lora_bias'): + self.lora_bias[adapter_name] = lora_bias + + if use_rslora: + self.scaling[adapter_name] = lora_alpha / (r ** 0.5) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + self._move_adapter_to_device_of_base_layer(adapter_name) + 
self.set_adapter(self.active_adapters) + + def reset_lora_parameters(self, adapter_name: str, init_lora_weights: bool): + """Reset LoRA parameters to initial values. + + Args: + adapter_name: Name of the adapter. + init_lora_weights: Initialization method. + """ + if init_lora_weights is False: + return + + if adapter_name in self.lora_A.keys(): + lora_a = self.lora_A[adapter_name] + lora_b = self.lora_B[adapter_name] + + if isinstance(lora_a, TEGroupedLinear): + weights_a = [getattr(lora_a, f'weight{i}') for i in range(lora_a.num_gemms)] + else: + weights_a = [lora_a.weight] + + if isinstance(lora_b, TEGroupedLinear): + weights_b = [getattr(lora_b, f'weight{i}') for i in range(lora_b.num_gemms)] + else: + weights_b = [lora_b.weight] + + for weight_a in weights_a: + if init_lora_weights is True: + nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5)) + elif init_lora_weights.lower() == 'gaussian': + nn.init.normal_(weight_a, std=1 / self.r[adapter_name]) + else: + raise ValueError(f'Unknown initialization {init_lora_weights=}') + + for weight_b in weights_b: + nn.init.zeros_(weight_b) + + if adapter_name in self.lora_embedding_A.keys(): + nn.init.zeros_(self.lora_embedding_A[adapter_name]) + nn.init.normal_(self.lora_embedding_B[adapter_name]) + + @contextmanager + def _patch_router_gating(self): + """Context manager to patch router gating with LoRA.""" + origin_gating = self.base_layer.__class__.gating + + def gating(_self, x): + result = origin_gating(_self, x) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(result.dtype) + + lora_result = F.linear(dropout(x), lora_A.weight.to(result.dtype)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = F.linear(lora_result, lora_B.weight.to(result.dtype)) + if 
isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + result = result + lora_result + return result + + self.base_layer.__class__.gating = gating + try: + yield + finally: + self.base_layer.__class__.gating = origin_gating + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): + """Forward pass with LoRA adaptation. + + Args: + x: Input tensor. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + Tuple of (output tensor, bias). + """ + previous_dtype = x.dtype + if self.disable_adapters and self.merged: + self.unmerge() + + if isinstance(self.base_layer, TELayerNormColumnParallelLinear): + if self.disable_adapters or self.merged: + self.base_layer.return_layernorm_output = False + result, bias = self.base_layer(x, *args, **kwargs) + else: + self.base_layer.return_layernorm_output = True + (result, x), bias = self.base_layer(x, *args, **kwargs) + elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)): + result, bias = self.base_layer(x, *args, **kwargs) + elif isinstance(self.base_layer, TopKRouter): + with self._patch_router_gating(): + result, bias = self.base_layer(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}') + + if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged: + if self.sequence_parallel and self.base_layer.parallel_mode == 'column': + x = gather_from_sequence_parallel_region(x) + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype + x = x.to(dtype) + + lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, 
TEGroupedLinear) else lora_A(dropout(x)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + + lora_result = lora_B(lora_result, *args, **kwargs) if isinstance(lora_B, TEGroupedLinear) else lora_B(lora_result) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + + lora_result = lora_result * scaling + + if self.sequence_parallel and self.base_layer.parallel_mode == 'row': + lora_result = scatter_to_sequence_parallel_region(lora_result) + + result = result + lora_result + + result = result.to(previous_dtype) + return result, bias + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + """Get sharded state dict for distributed checkpointing. + + Args: + prefix: Key prefix. + sharded_offsets: Sharding offsets. + metadata: Additional metadata. + + Returns: + Sharded state dictionary. + """ + from ..utils import tuners_sharded_state_dict + + sharded_state_dict = tuners_sharded_state_dict(self, prefix, sharded_offsets, metadata) + + if prefix.endswith('linear_fc1.'): + if isinstance(self.base_layer, TEGroupedLinear) and self.config.gated_linear_unit: + num_global_experts = ( + parallel_state.get_expert_model_parallel_world_size() * self.base_layer.num_gemms + ) + local_expert_indices_offset = ( + parallel_state.get_expert_model_parallel_rank() * self.base_layer.num_gemms + ) + ep_axis = len(sharded_offsets) + for i in range(self.base_layer.num_gemms): + new_sharded_offsets = ( + *sharded_offsets, + (ep_axis, local_expert_indices_offset + i, num_global_experts), + ) + for k in (f'{prefix}base_layer.weight{i}', f'{prefix}base_layer.bias{i}'): + if k in sharded_state_dict: + sharded_state_dict[k] = apply_swiglu_sharded_factory( + sharded_state_dict[k], new_sharded_offsets + ) + else: + for k, v in sharded_state_dict.items(): + if k in [f'{prefix}base_layer.weight', f'{prefix}base_layer.bias']: + sharded_state_dict[k] = 
apply_swiglu_sharded_factory( + sharded_state_dict[k], sharded_offsets + ) + return sharded_state_dict + + def get_delta_weights(self, adapter: str) -> List[torch.Tensor]: + """Compute the delta weight for the given adapter. + + Args: + adapter: The name of the adapter. + + Returns: + List of delta weight tensors. + """ + lora_A = self.lora_A[adapter] + lora_B = self.lora_B[adapter] + + if self.is_grouped: + weight_A = [getattr(lora_A, f'weight{i}') for i in range(lora_A.num_gemms)] + weight_B = [getattr(lora_B, f'weight{i}') for i in range(lora_B.num_gemms)] + else: + weight_A = [self.lora_A[adapter].weight] + weight_B = [self.lora_B[adapter].weight] + + output_tensor = [] + assert len(weight_A) == len(weight_B) + + for i in range(len(weight_B)): + output_tensor.append( + transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * self.scaling[adapter] + ) + + return output_tensor + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """Merge the active adapter weights into the base weights. + + Args: + safe_merge: If True, check for NaNs before merging. + adapter_names: List of adapter names to merge. 
+ """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + + if origin_device.type == 'cpu': + device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + self.to(device=device) + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + if safe_merge: + orig_weights = [weight.data.clone() for weight in orig_weights] + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + orig_weight += delta_weight + if not all(torch.isfinite(orig_weights[i]).all() for i in range(len(orig_weights))): + raise ValueError( + f'NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken' + ) + if self.is_grouped: + for i in range(base_layer.num_gemms): + weight = getattr(base_layer, f'weight{i}') + weight.data = orig_weights[i] + else: + base_layer.weight.data = orig_weights[0] + else: + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + orig_weight.data += delta_weight + + self.merged_adapters.append(active_adapter) + + if origin_device.type == 'cpu': + self.to(device=origin_device) + + def unmerge(self) -> None: + """Unmerge all merged adapter weights from the base weights.""" + if not self.merged: + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + + if origin_device.type == 'cpu': + device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + self.to(device=device) + + for active_adapter in self.merged_adapters: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + orig_weight.data -= delta_weight + + self.merged_adapters = [] + + if origin_device.type == 'cpu': + self.to(device=origin_device) + + +def dispatch_megatron( + target: torch.nn.Module, + adapter_name: str, + lora_config, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + """Dispatch function to replace Megatron linear layers with LoRA layers. + + Args: + target: The target module to potentially replace. + adapter_name: Name of the LoRA adapter. + lora_config: LoRA configuration. + **kwargs: Additional arguments for LoraParallelLinear. + + Returns: + LoraParallelLinear if target is a compatible layer, None otherwise. 
+ """ + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + linear_cls = (TELayerNormColumnParallelLinear, TELinear, TEGroupedLinear, TopKRouter) + if isinstance(target_base_layer, linear_cls): + new_module = LoraParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + + return new_module + + +# Register dispatch function with PEFT +try: + model.dispatch_megatron = dispatch_megatron +except Exception: + pass + diff --git a/src/twinkle/megatron/utils.py b/src/twinkle/megatron/utils.py new file mode 100644 index 00000000..d53bbf7b --- /dev/null +++ b/src/twinkle/megatron/utils.py @@ -0,0 +1,1034 @@ +# Copyright (c) twinkle authors. All rights reserved. +# Code reference: Adapted from ms-swift with modifications for twinkle's multi-tenant architecture. +# Original code markers: Functions marked with [SWIFT] are adapted from swift, +# Functions marked with [TWINKLE] are original implementations. 
+"""Utility functions for Megatron-Core integration.""" +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple +import threading + +import torch +import torch.nn as nn +import torch.distributed as dist + +import megatron.core +from megatron.core import parallel_state as mpu +from megatron.core.extensions.transformer_engine import ( + TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear +) +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default +from packaging import version +from peft import LoraConfig, get_peft_model + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + + +# ============================================================================= +# [SWIFT] Config mapping from HuggingFace to Megatron - adapted from swift +# ============================================================================= +CONFIG_MAPPING = { + 'num_layers': ['num_hidden_layers'], + 'hidden_size': ['hidden_size'], + 'mlp_ffn_hidden_size': ['intermediate_size_mlp'], + 'ffn_hidden_size': ['intermediate_size'], + 'num_attention_heads': ['num_attention_heads'], + 'num_query_groups': ['num_key_value_heads'], + 'max_position_embeddings': ['max_position_embeddings'], + 'norm_epsilon': ['rms_norm_eps'], + 'rotary_base': ['rope_theta'], + 'padded_vocab_size': ['vocab_size'], + 'attention_dropout': ['attention_dropout'], + 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], + 'swiglu': ['hidden_act'], + 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], + 'disable_bias_linear': ['mlp_bias'], + 'kv_channels': ['head_dim', 
'v_head_dim'], + 'architectures': ['architectures'], + # moe + 'moe_ffn_hidden_size': ['moe_intermediate_size'], + 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], + 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'], + 'moe_router_num_groups': ['n_group'], + 'moe_router_group_topk': ['topk_group'], + 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], + 'moe_router_pre_softmax': ['norm_topk_prob'], + # deepseek + 'q_lora_rank': ['q_lora_rank'], + 'kv_lora_rank': ['kv_lora_rank'], + 'moe_router_score_function': ['scoring_func'], + 'moe_router_bias_update_rate': ['aux_loss_alpha'], + 'qk_head_dim': ['qk_nope_head_dim'], + 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], + 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], + 'qk_layernorm': ['use_qk_norm'], + # other + 'original_max_position_embeddings': ['original_max_position_embeddings'], + 'partial_rotary_factor': ['partial_rotary_factor'], + 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], + 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], + 'window_size': ['sliding_window'], + 'layer_types': ['layer_types'], +} + + +# ============================================================================= +# [TWINKLE] Multi-tenant Process Group Management +# ============================================================================= +class TenantProcessGroupManager: + """Manager for multi-tenant process groups. + + [TWINKLE] This is an original implementation for twinkle's multi-tenant architecture. + + In a multi-tenant scenario, multiple users may share the same base model in a single + process, each with their own LoRA adapters. To avoid communication interference between + tenants, we need to maintain separate process groups for each tenant. + + This class provides: + 1. Per-tenant process group isolation + 2. 
Context managers to temporarily switch active process groups + 3. Patching of Megatron's communication operations to use tenant-specific groups + + Example: + # Create tenant-specific groups + manager = TenantProcessGroupManager() + manager.register_tenant('user_1', tp_ranks=[0, 1], dp_ranks=[0, 2]) + manager.register_tenant('user_2', tp_ranks=[2, 3], dp_ranks=[1, 3]) + + # Use tenant context for operations + with manager.tenant_context('user_1'): + # All Megatron communications will use user_1's process groups + model.forward(...) + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + """Singleton pattern for global access.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self._initialized = True + + # Tenant ID -> Process Groups mapping + self._tenant_groups: Dict[str, Dict[str, dist.ProcessGroup]] = {} + # Current active tenant (thread-local) + self._active_tenant = threading.local() + # Original Megatron parallel state functions (for patching) + self._original_functions = {} + # Whether patching is active + self._patched = False + + def register_tenant( + self, + tenant_id: str, + tp_ranks: Optional[List[int]] = None, + pp_ranks: Optional[List[int]] = None, + dp_ranks: Optional[List[int]] = None, + ep_ranks: Optional[List[int]] = None, + cp_ranks: Optional[List[int]] = None, + ) -> None: + """Register a tenant with specific process group ranks. + + Args: + tenant_id: Unique identifier for the tenant. + tp_ranks: Ranks for tensor parallel group. + pp_ranks: Ranks for pipeline parallel group. + dp_ranks: Ranks for data parallel group. + ep_ranks: Ranks for expert parallel group. + cp_ranks: Ranks for context parallel group. 
+ """ + if tenant_id in self._tenant_groups: + return # Already registered + + groups = {} + + # Create process groups for each parallelism dimension + if tp_ranks and len(tp_ranks) > 1: + groups['tp'] = dist.new_group(tp_ranks) + if pp_ranks and len(pp_ranks) > 1: + groups['pp'] = dist.new_group(pp_ranks) + if dp_ranks and len(dp_ranks) > 1: + groups['dp'] = dist.new_group(dp_ranks) + if ep_ranks and len(ep_ranks) > 1: + groups['ep'] = dist.new_group(ep_ranks) + if cp_ranks and len(cp_ranks) > 1: + groups['cp'] = dist.new_group(cp_ranks) + + self._tenant_groups[tenant_id] = groups + + def unregister_tenant(self, tenant_id: str) -> None: + """Unregister a tenant and destroy its process groups. + + Args: + tenant_id: Tenant to unregister. + """ + if tenant_id in self._tenant_groups: + groups = self._tenant_groups.pop(tenant_id) + for group in groups.values(): + dist.destroy_process_group(group) + + def get_tenant_group(self, tenant_id: str, group_type: str) -> Optional[dist.ProcessGroup]: + """Get process group for a tenant. + + Args: + tenant_id: Tenant identifier. + group_type: Type of group ('tp', 'pp', 'dp', 'ep', 'cp'). + + Returns: + Process group or None if not found. + """ + if tenant_id in self._tenant_groups: + return self._tenant_groups[tenant_id].get(group_type) + return None + + @property + def active_tenant(self) -> Optional[str]: + """Get the currently active tenant ID.""" + return getattr(self._active_tenant, 'id', None) + + @contextmanager + def tenant_context(self, tenant_id: str): + """Context manager to set active tenant for communications. + + All Megatron communication operations within this context will use + the tenant-specific process groups. 
This includes: + + - Tensor Parallel (TP): get_tensor_model_parallel_group/rank/world_size + - Data Parallel (DP): get_data_parallel_group/rank/world_size + - Pipeline Parallel (PP): get_pipeline_model_parallel_group/rank/world_size, + is_pipeline_first_stage, is_pipeline_last_stage + - Expert Parallel (EP): get_expert_model_parallel_group/rank/world_size + - Context Parallel (CP): get_context_parallel_group/rank/world_size + + Args: + tenant_id: Tenant to activate. + + Example: + manager = get_tenant_manager() + manager.register_tenant('user_1', tp_ranks=[0, 1], dp_ranks=[0, 2]) + + with manager.tenant_context('user_1'): + # All Megatron communications use user_1's groups + output = model.forward(input_ids) + """ + old_tenant = self.active_tenant + self._active_tenant.id = tenant_id + + # Apply all patches if not already done + if not self._patched: + self._patch_megatron_parallel_state() + self._patch_tensor_parallel_comms() + self._patch_expert_parallel_comms() + self._patch_context_parallel_comms() + + try: + yield + finally: + self._active_tenant.id = old_tenant + + def _patch_megatron_parallel_state(self) -> None: + """Patch Megatron's parallel_state to use tenant-specific groups. 
+ + This patches the following functions for full TP/PP/DP/EP/CP support: + - get_tensor_model_parallel_group / get_tensor_model_parallel_world_size / get_tensor_model_parallel_rank + - get_data_parallel_group / get_data_parallel_world_size / get_data_parallel_rank + - get_pipeline_model_parallel_group / get_pipeline_model_parallel_world_size / get_pipeline_model_parallel_rank + - get_expert_model_parallel_group / get_expert_model_parallel_world_size / get_expert_model_parallel_rank + - get_context_parallel_group / get_context_parallel_world_size / get_context_parallel_rank + """ + if self._patched: + return + + # Save original functions + self._original_functions = { + # TP functions + 'get_tensor_model_parallel_group': mpu.get_tensor_model_parallel_group, + 'get_tensor_model_parallel_world_size': mpu.get_tensor_model_parallel_world_size, + 'get_tensor_model_parallel_rank': mpu.get_tensor_model_parallel_rank, + # DP functions + 'get_data_parallel_group': mpu.get_data_parallel_group, + 'get_data_parallel_world_size': mpu.get_data_parallel_world_size, + 'get_data_parallel_rank': mpu.get_data_parallel_rank, + # PP functions + 'get_pipeline_model_parallel_group': mpu.get_pipeline_model_parallel_group, + 'get_pipeline_model_parallel_world_size': mpu.get_pipeline_model_parallel_world_size, + 'get_pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank, + 'is_pipeline_first_stage': mpu.is_pipeline_first_stage, + 'is_pipeline_last_stage': mpu.is_pipeline_last_stage, + # EP functions + 'get_expert_model_parallel_group': mpu.get_expert_model_parallel_group, + 'get_expert_model_parallel_world_size': mpu.get_expert_model_parallel_world_size, + 'get_expert_model_parallel_rank': mpu.get_expert_model_parallel_rank, + # CP functions + 'get_context_parallel_group': mpu.get_context_parallel_group, + 'get_context_parallel_world_size': mpu.get_context_parallel_world_size, + 'get_context_parallel_rank': mpu.get_context_parallel_rank, + } + + manager = self + + def 
_make_group_getter(group_type: str, original_func_name: str): + """Create patched group getter function.""" + def patched_func(*args, **kwargs): + tenant = manager.active_tenant + if tenant and tenant in manager._tenant_groups: + group = manager.get_tenant_group(tenant, group_type) + if group is not None: + return group + return manager._original_functions[original_func_name](*args, **kwargs) + return patched_func + + def _make_world_size_getter(group_type: str, original_func_name: str): + """Create patched world_size getter function.""" + def patched_func(*args, **kwargs): + tenant = manager.active_tenant + if tenant and tenant in manager._tenant_groups: + group = manager.get_tenant_group(tenant, group_type) + if group is not None: + return dist.get_world_size(group) + return manager._original_functions[original_func_name](*args, **kwargs) + return patched_func + + def _make_rank_getter(group_type: str, original_func_name: str): + """Create patched rank getter function.""" + def patched_func(*args, **kwargs): + tenant = manager.active_tenant + if tenant and tenant in manager._tenant_groups: + group = manager.get_tenant_group(tenant, group_type) + if group is not None: + return dist.get_rank(group) + return manager._original_functions[original_func_name](*args, **kwargs) + return patched_func + + # Apply patches for TP + mpu.get_tensor_model_parallel_group = _make_group_getter('tp', 'get_tensor_model_parallel_group') + mpu.get_tensor_model_parallel_world_size = _make_world_size_getter('tp', 'get_tensor_model_parallel_world_size') + mpu.get_tensor_model_parallel_rank = _make_rank_getter('tp', 'get_tensor_model_parallel_rank') + + # Apply patches for DP + mpu.get_data_parallel_group = _make_group_getter('dp', 'get_data_parallel_group') + mpu.get_data_parallel_world_size = _make_world_size_getter('dp', 'get_data_parallel_world_size') + mpu.get_data_parallel_rank = _make_rank_getter('dp', 'get_data_parallel_rank') + + # Apply patches for PP + 
mpu.get_pipeline_model_parallel_group = _make_group_getter('pp', 'get_pipeline_model_parallel_group') + mpu.get_pipeline_model_parallel_world_size = _make_world_size_getter('pp', 'get_pipeline_model_parallel_world_size') + mpu.get_pipeline_model_parallel_rank = _make_rank_getter('pp', 'get_pipeline_model_parallel_rank') + + # Patch is_pipeline_first/last_stage + def patched_is_pipeline_first_stage(*args, **kwargs): + tenant = manager.active_tenant + if tenant and tenant in manager._tenant_groups: + group = manager.get_tenant_group(tenant, 'pp') + if group is not None: + return dist.get_rank(group) == 0 + return manager._original_functions['is_pipeline_first_stage'](*args, **kwargs) + + def patched_is_pipeline_last_stage(*args, **kwargs): + tenant = manager.active_tenant + if tenant and tenant in manager._tenant_groups: + group = manager.get_tenant_group(tenant, 'pp') + if group is not None: + return dist.get_rank(group) == dist.get_world_size(group) - 1 + return manager._original_functions['is_pipeline_last_stage'](*args, **kwargs) + + mpu.is_pipeline_first_stage = patched_is_pipeline_first_stage + mpu.is_pipeline_last_stage = patched_is_pipeline_last_stage + + # Apply patches for EP + mpu.get_expert_model_parallel_group = _make_group_getter('ep', 'get_expert_model_parallel_group') + mpu.get_expert_model_parallel_world_size = _make_world_size_getter('ep', 'get_expert_model_parallel_world_size') + mpu.get_expert_model_parallel_rank = _make_rank_getter('ep', 'get_expert_model_parallel_rank') + + # Apply patches for CP + mpu.get_context_parallel_group = _make_group_getter('cp', 'get_context_parallel_group') + mpu.get_context_parallel_world_size = _make_world_size_getter('cp', 'get_context_parallel_world_size') + mpu.get_context_parallel_rank = _make_rank_getter('cp', 'get_context_parallel_rank') + + self._patched = True + + def unpatch_megatron_parallel_state(self) -> None: + """Restore original Megatron parallel_state functions.""" + if not self._patched: + return + 
+
+        for name, func in self._original_functions.items():
+            setattr(mpu, name, func)
+
+        self._patched = False
+        self._original_functions = {}
+
+    def _patch_tensor_parallel_comms(self) -> None:
+        """Ensure TP communication ops use tenant-specific process groups.
+
+        These mappings ops use the already-patched get_tensor_model_parallel_group():
+        - mappings.copy_to_tensor_model_parallel_region
+        - mappings.reduce_from_tensor_model_parallel_region
+        - mappings.scatter_to_tensor_model_parallel_region
+        - mappings.gather_from_tensor_model_parallel_region
+        """
+        try:
+            from megatron.core.tensor_parallel import mappings
+        except ImportError:
+            return
+
+        if hasattr(self, '_tp_comms_patched') and self._tp_comms_patched:
+            return
+
+        # Save original functions
+        self._original_tp_functions = {}
+
+        # The mappings module uses get_tensor_model_parallel_group() internally,
+        # which we've already patched. No additional patches needed here.
+        # The patched group getters will be used automatically.
+
+        self._tp_comms_patched = True
+
+    def _patch_expert_parallel_comms(self) -> None:
+        """Patch expert parallel communication operations for MoE models.
+
+        For MoE models, expert parallel communications use:
+        - get_expert_model_parallel_group
+        - get_expert_tensor_parallel_group (if using expert tensor parallelism)
+
+        Since we've patched the group getters, the communications will
+        automatically use tenant-specific groups.
+        """
+        # Expert parallel communications use the patched group getters
+        # No additional patches needed
+        pass
+
+    def _patch_context_parallel_comms(self) -> None:
+        """Patch context parallel communication operations.
+
+        Context parallelism communications include:
+        - Ring attention communications
+        - CP all-to-all operations
+
+        These use get_context_parallel_group() which we've patched. 
+ """ + # CP communications use the patched group getters + # No additional patches needed + pass + + +# Global instance for easy access +_tenant_manager: Optional[TenantProcessGroupManager] = None + + +def get_tenant_manager() -> TenantProcessGroupManager: + """Get the global tenant process group manager. + + + Returns: + The singleton TenantProcessGroupManager instance. + """ + global _tenant_manager + if _tenant_manager is None: + _tenant_manager = TenantProcessGroupManager() + return _tenant_manager + + +# ============================================================================= +# [SWIFT] Layer finding utilities - adapted from swift +# ============================================================================= +def find_layers(model: nn.Module, cond_fn) -> List[str]: + """Find all layers in model matching condition function. + + [SWIFT] Adapted from swift. + + Args: + model: The model to search. + cond_fn: Callable(name, module) -> bool. + + Returns: + List of matching layer names. + """ + result = [] + for name, module in model.named_modules(): + if cond_fn(name, module): + result.append(name) + return result + + +def find_all_linears(model: nn.Module) -> List[str]: + """Find all linear layers suitable for LoRA in a Megatron model. + + [SWIFT] Adapted from swift. + + Args: + model: The Megatron model. + + Returns: + List of layer names suitable for LoRA. + """ + def _cond(name: str, module: nn.Module) -> bool: + if name == 'output_layer': + return False + if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, nn.Linear)): + return True + return False + + return find_layers(model, _cond) + + +def find_router(model: nn.Module) -> List[str]: + """Find all MoE router layers in a Megatron model. + + [SWIFT] Adapted from swift. + + Args: + model: The Megatron model. + + Returns: + List of router layer names. 
+ """ + return find_layers(model, lambda name, module: isinstance(module, TopKRouter)) + + +def find_embedding(model: nn.Module) -> List[str]: + """Find all embedding layers in a Megatron model. + + [SWIFT] Adapted from swift. + + Args: + model: The Megatron model. + + Returns: + List of embedding layer names. + """ + return find_layers(model, lambda name, module: isinstance(module, LanguageModelEmbedding)) + + +def get_target_modules(model: nn.Module, target_modules: List[str]) -> List[str]: + """Expand target module specifications to actual module names. + + [SWIFT] Adapted from swift. + + Args: + model: The Megatron model. + target_modules: List of target module specs, may include 'all-linear', etc. + + Returns: + Expanded list of target module names. + """ + result = target_modules.copy() + if 'all-linear' in result: + result.remove('all-linear') + result += find_all_linears(model) + if 'all-embedding' in result: + result.remove('all-embedding') + result += find_embedding(model) + if 'all-router' in result: + result.remove('all-router') + result += find_router(model) + return list(set(result)) + + +def set_linear_is_expert(model: nn.Module): + """Mark expert linear layers in MoE models. + + [SWIFT] Adapted from swift. + + Args: + model: The Megatron model. + """ + for name, module in model.named_modules(): + if '.local_experts.' in name and isinstance( + module, (TELinear, TELayerNormColumnParallelLinear) + ): + module.is_expert = True + elif isinstance(module, TEGroupedLinear): + module.is_expert = True + + +def deep_getattr(obj: Any, attr: str, default: Any = None) -> Any: + """Get nested attribute using dot notation. + + Args: + obj: The object. + attr: Dot-separated attribute path. + default: Default value if attribute not found. + + Returns: + The attribute value or default. 
+ """ + try: + for a in attr.split('.'): + obj = getattr(obj, a) + return obj + except AttributeError: + return default + + +# ============================================================================= +# [SWIFT] Config conversion - adapted from swift with Qwen3 enhancements +# ============================================================================= +def _convert_hf_config(config, _internal_call: bool = False) -> Dict[str, Any]: + """Convert HuggingFace config to Megatron config dict. + + [SWIFT] Adapted from swift. + + Args: + config: HuggingFace model config. + _internal_call: Internal flag for recursion. + + Returns: + Megatron-compatible config dict. + """ + megatron_config = {} + for k, hf_keys in CONFIG_MAPPING.items(): + for hf_k in hf_keys: + if hasattr(config, hf_k): + hf_v = getattr(config, hf_k) + if hf_v is None: + continue + if k == 'rotary_base': + megatron_config[k] = int(hf_v) + elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: + megatron_config[k] = not hf_v + elif k == 'swiglu': + if hf_v == 'silu': + megatron_config[k] = True + else: + if k == 'kv_lora_rank': + megatron_config['multi_latent_attention'] = True + elif k == 'architectures': + if _internal_call: + k = 'llm_architectures' + megatron_config[k] = hf_v + break + + # Handle nested configs + for key in ['text_config', 'llm_config', 'thinker_config']: + if hasattr(config, key): + megatron_config.update(_convert_hf_config(getattr(config, key), _internal_call=True)) + + # Compat llama3 rope scaling + if getattr(config, 'rope_scaling', None) is not None: + if isinstance(config.rope_scaling, int): + megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'} + elif isinstance(config.rope_scaling, dict): + megatron_config['rope_scaling'] = config.rope_scaling + + return megatron_config + + +def convert_hf_config(config) -> Dict[str, Any]: + """Convert HuggingFace config to Megatron-compatible config. 
+ + [SWIFT] Adapted from swift with Qwen3 specific handling. + + Args: + config: HuggingFace model config. + + Returns: + Megatron-compatible config dict. + """ + res = _convert_hf_config(config) + + # Process architectures + architectures = res.get('architectures') + if isinstance(architectures, list) and architectures: + architectures = architectures[0] + res['architectures'] = architectures + + llm_architectures = res.get('llm_architectures') or architectures + if isinstance(llm_architectures, list) and llm_architectures: + llm_architectures = llm_architectures[0] + res['llm_architectures'] = llm_architectures + + # Process MoE settings + first_k_dense_replace = res.pop('first_k_dense_replace', None) + n_shared_experts = res.pop('n_shared_experts', None) + + # ==== Qwen3 Dense Model specific settings ==== + if llm_architectures == 'Qwen3ForCausalLM': + res['qk_layernorm'] = True + # Qwen3 uses SwiGLU activation + res['swiglu'] = True + # Qwen3 typically doesn't use bias in linear layers + res['disable_bias_linear'] = True + + # ==== Qwen3 MoE Model specific settings ==== + if llm_architectures == 'Qwen3MoeForCausalLM': + res['qk_layernorm'] = True + res['swiglu'] = True + res['disable_bias_linear'] = True + # Qwen3 MoE uses shared expert gate + res['use_shared_expert_gate'] = True + # Remove ffn_hidden_size as MoE uses moe_ffn_hidden_size + res.pop('ffn_hidden_size', None) + + # DeepSeek models + if llm_architectures in {'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM'}: + res['qk_layernorm'] = True + res['moe_router_load_balancing_type'] = 'seq_aux_loss' + res.pop('num_query_groups', None) + + # Handle rope scaling + rope_scaling = res.get('rope_scaling') or {} + if 'partial_rotary_factor' not in res and 'partial_rotary_factor' in rope_scaling: + res['partial_rotary_factor'] = rope_scaling['partial_rotary_factor'] + if rope_scaling.get('mrope_section') is not None: + res['position_embedding_type'] = 'mrope' + res['mrope_section'] = rope_scaling['mrope_section'] 
+
+    # MoE layer frequency: first_k_dense_replace dense layers, then MoE layers.
+    # NOTE(review): the value is a Python-list-literal string — presumably
+    # evaluated by Megatron's moe_layer_freq argument parsing; confirm.
+    if first_k_dense_replace is not None:
+        res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}'
+    if res.get('moe_router_score_function', 'softmax') == 'sigmoid' and 'moe_router_enable_expert_bias' not in res:
+        res['moe_router_enable_expert_bias'] = True
+    if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res:
+        res['moe_shared_expert_intermediate_size'] = n_shared_experts * res.get('moe_ffn_hidden_size', res.get('ffn_hidden_size', 0))
+
+    return res
+
+
+@contextmanager
+def patch_deepcopy():
+    """Context manager to handle tp_group in deepcopy operations.
+
+    [SWIFT] Adapted from swift.
+
+    WHY THIS IS NECESSARY:
+    ----------------------
+    Megatron-Core's TransformerEngine linear layers (TELinear, TEColumnParallelLinear, etc.)
+    store a reference to their tensor parallel process group in the `tp_group` attribute.
+
+    When PEFT's get_peft_model() is called, it internally uses copy.deepcopy() to create
+    copies of certain modules. However, torch.distributed.ProcessGroup objects cannot be
+    pickled or deepcopied because:
+
+    1. ProcessGroup objects contain native CUDA/NCCL handles that are process-specific
+    2. These handles cannot be serialized and recreated in a different memory context
+    3. Attempting to deepcopy them raises: "RuntimeError: Cannot pickle ProcessGroup"
+
+    This patch temporarily sets tp_group to None during deepcopy, then restores it
+    after the copy is complete. This allows PEFT to work with Megatron modules while
+    preserving the correct process group references.
+
+    USAGE:
+    ------
+    ```python
+    with patch_deepcopy():
+        model = get_peft_model(megatron_model, lora_config)
+    ```
+
+    Without this patch, the above code would fail with a pickling error.
+    """
+    import copy
+    _origin_deepcopy = copy.deepcopy
+
+    def new_deepcopy(x, *args, **kwargs):
+        # Detach the un-picklable ProcessGroup before copying, restore on
+        # both the original and the copy afterwards.
+        if getattr(x, 'tp_group', None) is not None:
+            origin_tp_group = x.tp_group
+            x.tp_group = None
+            res = _origin_deepcopy(x, *args, **kwargs)
+            x.tp_group = origin_tp_group
+            res.tp_group = origin_tp_group
+            return res
+        else:
+            return _origin_deepcopy(x, *args, **kwargs)
+
+    # NOTE(review): copy.deepcopy is patched module-wide for the duration of
+    # the context — not safe against concurrent deepcopy calls in other threads.
+    copy.deepcopy = new_deepcopy
+    try:
+        yield
+    finally:
+        copy.deepcopy = _origin_deepcopy
+
+
+# =============================================================================
+# [SWIFT] Sharded state dict for tuners - adapted from swift
+# =============================================================================
+def tuners_sharded_state_dict(
+    module: nn.Module,
+    prefix: str = '',
+    sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+    metadata: Optional[dict] = None,
+) -> Dict[str, Any]:
+    """Generate sharded state dict for PEFT tuners.
+
+    [SWIFT] Adapted from swift.
+
+    Args:
+        module: The module to generate state dict for.
+        prefix: Key prefix.
+        sharded_offsets: Sharding offsets for distributed checkpointing.
+        metadata: Additional metadata.
+
+    Returns:
+        Sharded state dictionary.
+    """
+    sharded_state_dict = {}
+    # Save parameters
+    module._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
+    sharded_state_dict = make_sharded_tensors_for_checkpoint(
+        sharded_state_dict, prefix, sharded_offsets=sharded_offsets
+    )
+    # Recurse into submodules
+    for name, child in module.named_children():
+        # Class-name heuristic for dict-style containers (e.g. ModuleDict /
+        # ParameterDict), whose children need their key in the prefix.
+        if 'Dict' in child.__class__.__name__:
+            modules = child.named_children()
+        else:
+            modules = [(None, child)]
+        for n, m in modules:
+            _prefix = f'{prefix}{name}.' if n is None else f'{prefix}{name}.{n}.'
+            sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata))
+    return sharded_state_dict
+
+
+def prepare_mcore_model(
+    model: nn.Module,
+    train_type: str = 'lora',
+    lora_config: Optional[Dict[str, Any]] = None,
+    freeze_parameters: Optional[List[str]] = None,
+    tenant_id: Optional[str] = None,
+) -> nn.Module:
+    """Prepare Megatron-Core model for training.
+
+    Args:
+        model: The Megatron model.
+        train_type: Training type ('full' or 'lora').
+        lora_config: LoRA configuration dict.
+        freeze_parameters: List of parameter names to freeze.
+        tenant_id: Optional tenant ID for multi-tenant isolation.
+
+    Returns:
+        Prepared model.
+    """
+    from contextlib import nullcontext
+
+    # Set up tenant context if provided; otherwise use the stdlib no-op
+    # context manager (replaces the obscure one-shot generator trick
+    # `contextmanager(lambda: (yield))()`, which reads as a syntax puzzle
+    # and is single-use by construction).
+    if tenant_id is not None:
+        manager = get_tenant_manager()
+        context = manager.tenant_context(tenant_id)
+    else:
+        context = nullcontext()
+
+    with context:
+        if train_type == 'full':
+            if freeze_parameters:
+                # Freeze every parameter whose name contains any of the
+                # given substrings.
+                for name, param in model.named_parameters():
+                    if any(fp in name for fp in freeze_parameters):
+                        param.requires_grad = False
+        elif train_type == 'lora':
+            set_linear_is_expert(model)
+            if lora_config is not None:
+                model = prepare_lora_model(model, lora_config)
+    return model
+
+
+def prepare_lora_model(
+    model: nn.Module,
+    lora_config: Dict[str, Any],
+) -> nn.Module:
+    """Add LoRA adapters to Megatron model.
+
+    Args:
+        model: The Megatron model.
+        lora_config: LoRA configuration dict with keys:
+            - r: LoRA rank
+            - lora_alpha: LoRA alpha
+            - lora_dropout: Dropout rate
+            - target_modules: Target module names
+            - use_rslora: Use rank-stabilized LoRA
+
+    Returns:
+        Model with LoRA adapters.
+    """
+    set_linear_is_expert(model)
+
+    # Expand aliases such as 'all-linear' into concrete module names.
+    target_modules = get_target_modules(model, lora_config.get('target_modules', ['all-linear']))
+
+    peft_config = LoraConfig(
+        task_type='CAUSAL_LM',
+        r=lora_config.get('r', 8),
+        lora_alpha=lora_config.get('lora_alpha', 32),
+        lora_dropout=lora_config.get('lora_dropout', 0.0),
+        target_modules=target_modules,
+        bias=lora_config.get('bias', 'none'),
+        use_rslora=lora_config.get('use_rslora', False),
+    )
+
+    # patch_deepcopy: TE layers hold ProcessGroup refs that cannot be
+    # deepcopied by PEFT (see patch_deepcopy's docstring).
+    with patch_deepcopy():
+        model = get_peft_model(model, peft_config)
+
+    return model
+
+
+# =============================================================================
+# [SWIFT] Layer spec utilities - adapted from swift
+# =============================================================================
+def get_local_layer_specs(config, layer_specs: List, vp_stage: Optional[int] = None):
+    """Get local layer specifications for current pipeline rank.
+
+    [SWIFT] Adapted from swift.
+
+    Args:
+        config: Megatron transformer config.
+        layer_specs: Full list of layer specifications.
+        vp_stage: Virtual pipeline stage index.
+
+    Returns:
+        Local layer specifications for this rank.
+    """
+    # `vp_stage` kwarg only exists in megatron-core >= 0.13.
+    kwargs = {'vp_stage': vp_stage} if mcore_013 else {}
+    num_layers_to_build = get_num_layers_to_build(config, **kwargs)
+
+    if getattr(config, 'pipeline_model_parallel_layout', None) is not None:
+        # Explicit PP layout: ask the layout object which decoder layers
+        # belong to this rank/stage.
+        from megatron.core.transformer.enums import LayerType
+        local_layer_specs = [
+            layer_specs[layer_id] for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list(
+                layer_type=LayerType.decoder, **kwargs)
+        ]
+    else:
+        # Default layout: contiguous slice starting at this rank's offset.
+        offset = get_transformer_layer_offset(config, **kwargs)
+        local_layer_specs = layer_specs[offset:offset + num_layers_to_build]
+    return local_layer_specs
+
+
+def get_padding_to(
+    tensor_model_parallel_size: int = 1,
+    context_parallel_size: int = 1,
+    sequence_parallel: bool = False,
+    fp8_format: Optional[str] = None,
+    fp8_recipe: Optional[str] = None,
+    attention_backend: Optional[str] = None,
+) -> Optional[int]:
+    """Get padding size for sequence length.
+
+    Args:
+        tensor_model_parallel_size: TP size.
+        context_parallel_size: CP size.
+        sequence_parallel: Whether sequence parallel is enabled.
+        fp8_format: FP8 format if used.
+        fp8_recipe: FP8 recipe if used.
+        attention_backend: Attention backend type.
+
+    Returns:
+        Padding size or None.
+    """
+    padding_to = None
+    # Sequence-parallel TP splits the sequence across TP ranks, so length
+    # must be divisible by TP size; CP likewise splits across CP ranks.
+    if tensor_model_parallel_size > 1 and sequence_parallel:
+        padding_to = tensor_model_parallel_size
+    if context_parallel_size > 1:
+        padding_to = (padding_to or 1) * context_parallel_size
+    origin_padding_to = padding_to
+
+    # NOTE(review): the 128 / 8 / 16 / 64 multipliers below presumably come
+    # from TransformerEngine FP8 and fused-attention alignment requirements
+    # — confirm against the TE version in use.
+    if fp8_recipe == 'blockwise':
+        padding_to = (padding_to or 1) * 128
+    elif fp8_format is not None:
+        padding_to = max((padding_to or 1) * 8, 16)
+
+    if attention_backend == 'fused':
+        padding_to = max(padding_to or 1, ((origin_padding_to) or 1) * 64)
+
+    return padding_to
+
+
+# =============================================================================
+# [SWIFT] Forward step helper - adapted from swift
+# =============================================================================
+def forward_step_helper(model: nn.Module, inputs: Dict[str, Any], config) -> Optional[torch.Tensor]:
+    """Helper for pipeline parallel forward step.
+
+    Handles communication between pipeline stages.
+
+    Args:
+        model: The model.
+        inputs: Input dict with position_ids, etc.
+        config: Configuration with parallel settings.
+
+    Returns:
+        Output tensor for last stage, None otherwise.
+    """
+    from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
+
+    # Stage 1: establish the activation shape. The first PP stage computes it
+    # from its own inputs; every later stage receives it from its predecessor.
+    if mpu.is_pipeline_first_stage():
+        micro_batch_size = 1  # use qkv_format 'thd'
+        if not getattr(config, 'padding_free', False):
+            micro_batch_size = config.micro_batch_size
+        seq_length = inputs['position_ids'].shape[-1]
+        if config.sequence_parallel:
+            # Sequence parallelism shards the sequence dim across TP ranks.
+            seq_length //= mpu.get_tensor_model_parallel_world_size()
+        recv_shape_buffer = torch.tensor(
+            [seq_length, micro_batch_size, config.hidden_size],
+            device=torch.cuda.current_device(),
+            dtype=torch.int64
+        )
+    else:
+        recv_shape_buffer = torch.empty((3,), device=torch.cuda.current_device(), dtype=torch.int64)
+        recv_from_prev_pipeline_rank_(recv_shape_buffer)
+
+    # Propagate the shape downstream before exchanging activations.
+    if not mpu.is_pipeline_last_stage():
+        send_to_next_pipeline_rank(recv_shape_buffer)
+    shape = recv_shape_buffer.tolist()
+
+    # Stage 2: non-first stages receive the previous stage's activations and
+    # feed them in via set_input_tensor (Megatron's PP input hook).
+    if not mpu.is_pipeline_first_stage():
+        dtype = config.params_dtype
+        recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype)
+        recv_from_prev_pipeline_rank_(recv_buffer)
+        model.set_input_tensor(recv_buffer)
+
+    output_tensor = model(**inputs)
+
+    # Stage 3: forward activations to the next stage; only the last stage
+    # returns a tensor to its caller.
+    if not mpu.is_pipeline_last_stage():
+        send_to_next_pipeline_rank(output_tensor)
+        output_tensor = None
+
+    return output_tensor
+
+
+class MegatronTrainerState:
+    """Lightweight trainer state for Megatron training.
+
+    Provides compatibility with transformers TrainerState interface.
+
+    Attributes:
+        global_step: The current training step.
+        max_steps: The total number of training steps.
+    """
+
+    def __init__(self, global_step: int = 0, max_steps: int = 0):
+        self.global_step = global_step
+        self.max_steps = max_steps
+
+    def update(self, global_step: Optional[int] = None, max_steps: Optional[int] = None):
+        # Only overwrite the fields the caller actually supplied.
+        if global_step is not None:
+            self.global_step = global_step
+        if max_steps is not None:
+            self.max_steps = max_steps
+
+    def __repr__(self) -> str:
+        return f'MegatronTrainerState(global_step={self.global_step}, max_steps={self.max_steps})'
+
+
+def get_model_parameter_info(model: nn.Module) -> Dict[str, Any]:
+    """Get parameter count information for a model.
+
+    Args:
+        model: The model.
+
+    Returns:
+        Dict with total_params, trainable_params, frozen_params and
+        trainable_ratio (0 when the model has no parameters).
+    """
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    frozen_params = total_params - trainable_params
+
+    return {
+        'total_params': total_params,
+        'trainable_params': trainable_params,
+        'frozen_params': frozen_params,
+        'trainable_ratio': trainable_params / total_params if total_params > 0 else 0,
+    }
diff --git a/src/twinkle/model/__init__.py b/src/twinkle/model/__init__.py
index 749cf477..28c709d2 100644
--- a/src/twinkle/model/__init__.py
+++ b/src/twinkle/model/__init__.py
@@ -1,3 +1,4 @@
 from .transformers import TransformersModel
 from .base import TwinkleModel
 from .multi_lora_transformers import MultiLoraTransformersModel
+from .megatron import MegatronModel
diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py
new file mode 100644
index 00000000..ce5c1830
--- /dev/null
+++ b/src/twinkle/model/megatron.py
@@ -0,0 +1,856 @@
+# Copyright (c) twinkle authors. All rights reserved.
+"""Megatron-Core model wrapper for twinkle training framework."""
+import contextlib
+import json
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Type, Union
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+
+import twinkle
+from twinkle import remote_class, remote_function, template, DeviceMesh
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.loss import Loss, VocabParallelCrossEntropyLoss
+from twinkle.processor import InputProcessor
+from twinkle.template import Template
+from twinkle.utils.plugin import Plugin
+from .base import TwinkleModel
+from .strategy import MegatronStrategy
+
+# Megatron-Core is an optional dependency: keep the module importable without
+# it and defer the hard failure to check_megatron_available().
+try:
+    import megatron.core
+    from megatron.core import parallel_state as mpu
+    from megatron.core.distributed import DistributedDataParallel as MegatronDDP
+    from packaging import version
+    MEGATRON_AVAILABLE = True
+    # Feature gate for APIs introduced in megatron-core 0.13 (e.g. vp_stage).
+    mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
+except ImportError:
+    MEGATRON_AVAILABLE = False
+    mcore_013 = False
+
+
+@dataclass
+class MegatronOptimizerGroup:
+    """Optimizer group for Megatron training.
+
+    Similar to OptimizerGroup but adapted for Megatron's distributed training.
+    """
+    # Per-adapter training state: identity, optimization objects, cached
+    # forward inputs/outputs, and gradient-accumulation bookkeeping.
+    adapter_name: Optional[str] = None
+    adapter_config: Any = None
+    optimizer: Optional[Optimizer] = None
+    lr_scheduler: Optional[LRScheduler] = None
+    inputs: Optional[Dict[str, Any]] = None
+    outputs: Optional[Dict[str, Any]] = None
+    loss_instance: Optional[Loss] = None
+    loss_value: Any = None
+    template: Optional[Template] = None
+    processor: Optional[InputProcessor] = None
+    gradient_accumulation_steps: int = 1
+    cur_step: int = 0
+    # BUGFIX: this attribute previously had no annotation, so @dataclass
+    # silently ignored it (it was a shared class attribute, not a field,
+    # unlike every other entry above). Annotate it to make it a real field.
+    dp_group: Any = None
+
+    def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool:
+        """Check if gradient synchronization should happen.
+
+        Falls back to this group's configured accumulation steps when the
+        caller passes None.
+        """
+        if gradient_accumulation_steps is None:
+            gradient_accumulation_steps = self.gradient_accumulation_steps
+        # Sync only on accumulation boundaries, never before the first step.
+        return self.cur_step % gradient_accumulation_steps == 0 and self.cur_step > 0
+
+
+_default_adapter_name = ''
+
+
+def check_megatron_available():
+    """Raise ImportError if Megatron-Core is not installed."""
+    if not MEGATRON_AVAILABLE:
+        raise ImportError(
+            "Megatron-Core is not installed. Please install it with: "
+            "pip install megatron-core"
+        )
+
+
+@remote_class(execute='all')
+class MegatronModel(TwinkleModel, nn.Module):
+    """Megatron-Core model wrapper for twinkle training framework.
+
+    Note: Uses execute='all' to create workers on all ranks, which is required
+    for Megatron's TP/DP parallelism where all ranks must participate in
+    collective operations like gradient all-reduce.
+
+    This class provides a similar API to TransformersModel but uses Megatron-Core
+    as the training backend, supporting TP/PP/CP/EP parallelism.
+
+    Args:
+        pretrained_model_name_or_path: HuggingFace model path or ID.
+        device_mesh: Twinkle DeviceMesh for distributed training.
+        tensor_model_parallel_size: Tensor parallel size.
+        pipeline_model_parallel_size: Pipeline parallel size.
+        context_parallel_size: Context parallel size.
+        expert_model_parallel_size: Expert parallel size.
+        sequence_parallel: Enable sequence parallelism.
+        mixed_precision: Mixed precision mode.
+        use_distributed_optimizer: Use Megatron's distributed optimizer.
+        **kwargs: Additional arguments passed to model initialization.
+    """
+
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str,
+        device_mesh: Optional[DeviceMesh] = None,
+        tensor_model_parallel_size: int = 1,
+        pipeline_model_parallel_size: int = 1,
+        context_parallel_size: int = 1,
+        expert_model_parallel_size: int = 1,
+        sequence_parallel: bool = False,
+        mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+        use_distributed_optimizer: bool = True,
+        load_weights: bool = True,
+        **kwargs,
+    ):
+        check_megatron_available()
+        nn.Module.__init__(self)
+
+        self.model_id = pretrained_model_name_or_path
+        # BUGFIX: also keep the HF-style attribute name — checkpoint-saving
+        # code reads `self.pretrained_model_name_or_path`, which was never
+        # set (only `model_id` was), causing an AttributeError on save.
+        self.pretrained_model_name_or_path = pretrained_model_name_or_path
+        self.device_mesh = device_mesh
+        self.mixed_precision = mixed_precision
+
+        # Create Megatron strategy
+        self.strategy = MegatronStrategy(
+            tensor_model_parallel_size=tensor_model_parallel_size,
+            pipeline_model_parallel_size=pipeline_model_parallel_size,
+            context_parallel_size=context_parallel_size,
+            expert_model_parallel_size=expert_model_parallel_size,
+            sequence_parallel=sequence_parallel,
+            use_distributed_optimizer=use_distributed_optimizer,
+            mixed_precision=mixed_precision,
+        )
+
+        # Initialize parallel state
+        self.strategy.initialize()
+
+        # Load HuggingFace config
+        model_path = HubOperation.download_model(pretrained_model_name_or_path)
+        self._load_hf_config(model_path)
+
+        # Create Megatron model
+        self.model = self._create_megatron_model(model_path, load_weights, **kwargs)
+
+        self._model_wrapped = False
+        # Use VocabParallelCrossEntropyLoss by default for Megatron
+        # This correctly handles vocab sharding in Tensor Parallelism
+        self.optimizer_group: Dict[str, MegatronOptimizerGroup] = {
+            _default_adapter_name: MegatronOptimizerGroup(loss_instance=VocabParallelCrossEntropyLoss())
+        }
+
+    def _load_hf_config(self, model_path: str):
+        """Load HuggingFace model config into `self.hf_config`."""
+        from transformers import AutoConfig
+        self.hf_config = AutoConfig.from_pretrained(model_path)
+
+    def _create_megatron_model(
+        self,
+        model_path: str,
+        load_weights: 
bool = True,
+        **kwargs,
+    ) -> nn.Module:
+        """Create Megatron model from HuggingFace checkpoint.
+
+        Args:
+            model_path: Path to HuggingFace model.
+            load_weights: Whether to load weights.
+            **kwargs: Additional arguments.
+
+        Returns:
+            Megatron model on GPU.
+        """
+        from twinkle.megatron.model.initializer import MegatronModelInitializer
+
+        # Map the mixed-precision mode to the parameter dtype (bf16 default).
+        params_dtype = torch.bfloat16
+        if self.mixed_precision == 'fp16':
+            params_dtype = torch.float16
+        elif self.mixed_precision == 'no':
+            params_dtype = torch.float32
+
+        initializer = MegatronModelInitializer(
+            tp_size=self.strategy.tp_size,
+            pp_size=self.strategy.pp_size,
+            cp_size=self.strategy.cp_size,
+            ep_size=self.strategy.ep_size,
+            sequence_parallel=self.strategy.sequence_parallel,
+            params_dtype=params_dtype,
+        )
+
+        # Create model
+        model = initializer.create_gpt_model(self.hf_config, **kwargs)
+
+        # Load weights
+        if load_weights:
+            initializer.load_from_hf(model, model_path, self.hf_config)
+
+        model = self._move_model_to_gpu(model)
+
+        return model
+
+    def _move_model_to_gpu(self, model: nn.Module) -> nn.Module:
+        """Move all parameters and buffers onto this rank's GPU."""
+        # NOTE(review): rank % device_count assumes ranks are packed densely
+        # and homogeneously per node; prefer the LOCAL_RANK env var when the
+        # launcher provides it — confirm against the launch setup.
+        local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
+        device = torch.device(f'cuda:{local_rank}')
+
+        model = model.to(device)
+
+        return model
+
+    def _lazy_wrap_model(self):
+        """Lazily wrap model with distributed wrapper."""
+        if not self._model_wrapped:
+            # Find an optimizer from any adapter group (prefer default, then first available)
+            optimizer = None
+            optimizer_adapter = None
+
+            # BUGFIX: the default group always exists (created in __init__),
+            # so the previous membership test made the fallback loop dead
+            # code — if only a named adapter had an optimizer, the model was
+            # never wrapped. Check the default group's *optimizer* instead.
+            default_group = self.optimizer_group.get(_default_adapter_name)
+            if default_group is not None and default_group.optimizer is not None:
+                optimizer = default_group.optimizer
+                optimizer_adapter = _default_adapter_name
+            else:
+                for name, group in self.optimizer_group.items():
+                    if group.optimizer is not None:
+                        optimizer = group.optimizer
+                        optimizer_adapter = name
+                        break
+
+            if optimizer is not None:
+                self.model, optimizer = 
self.strategy.wrap_model(self.model, optimizer)
+                self.optimizer_group[optimizer_adapter].optimizer = optimizer
+            self._model_wrapped = True
+
+    @remote_function()
+    def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+        """Forward pass with Megatron model.
+
+        Args:
+            inputs: Model inputs.
+            **kwargs: Additional arguments including adapter_name.
+
+        Returns:
+            Model outputs.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+        self._lazy_wrap_model()
+
+        # Encode inputs if needed (raw samples lack 'input_ids' until the
+        # template tokenizes them).
+        if isinstance(inputs, dict) and 'input_ids' not in inputs:
+            if optimizer_config.template is not None:
+                inputs = optimizer_config.template.encode(inputs)
+        # BUGFIX: guard against an empty batch list before peeking at
+        # inputs[0] (previously raised IndexError on []).
+        if isinstance(inputs, list) and inputs and 'input_ids' not in inputs[0]:
+            if optimizer_config.template is not None:
+                inputs = optimizer_config.template.batch_encode(inputs)
+
+        # Process inputs
+        processor: InputProcessor = optimizer_config.processor
+        if processor is not None:
+            inputs: Dict[str, Any] = processor(inputs)
+
+        # Labels are not a model input; stash them and restore afterwards so
+        # calculate_loss() can read them from optimizer_config.inputs.
+        labels = inputs.pop('labels', None)
+
+        # Forward through model
+        outputs = self._forward_step(inputs)
+
+        inputs['labels'] = labels
+        optimizer_config.inputs = inputs
+        optimizer_config.outputs = outputs
+        return outputs
+
+    def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Execute forward step with pipeline parallelism support.
+
+        Args:
+            inputs: Processed inputs.
+
+        Returns:
+            Model outputs.
+        """
+        # Handle pipeline parallelism
+        if self.strategy.pp_size > 1:
+            return self._forward_step_pipeline(inputs)
+        else:
+            return self._forward_step_simple(inputs)
+
+    def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Simple forward step without pipeline parallelism."""
+        model = self.strategy.unwrap_model(self.model)
+
+        # Prepare inputs for Megatron
+        input_ids = inputs.get('input_ids')
+        attention_mask = inputs.get('attention_mask')
+        position_ids = inputs.get('position_ids')
+
+        # Create position_ids if not provided: [0..seq_len) broadcast over
+        # the batch dimension.
+        if position_ids is None and input_ids is not None:
+            position_ids = torch.arange(
+                input_ids.shape[1],
+                device=input_ids.device,
+                dtype=torch.long,
+            ).unsqueeze(0).expand(input_ids.shape[0], -1)
+
+        # Forward pass
+        outputs = model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+        )
+
+        return {'logits': outputs}
+
+    def _forward_step_pipeline(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Forward step with pipeline parallelism."""
+        from twinkle.megatron.utils import forward_step_helper
+
+        model = self.strategy.unwrap_model(self.model)
+
+        # Use pipeline forward helper (inter-stage send/recv); only the last
+        # PP stage gets a non-None output tensor.
+        output = forward_step_helper(
+            model,
+            inputs,
+            model.config,
+        )
+
+        if output is not None:
+            return {'logits': output}
+        return {}
+
+    @remote_function()
+    def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
+        """Forward pass without gradient computation.
+
+        Note: still caches inputs/outputs on the adapter's optimizer group,
+        exactly like forward().
+
+        Args:
+            inputs: Model inputs.
+            **kwargs: Additional arguments.
+
+        Returns:
+            Model outputs.
+        """
+        with torch.no_grad():
+            return self.forward(inputs=inputs, **kwargs)
+
+    @remote_function(collect='avg')
+    def calculate_loss(self, **kwargs):
+        """Calculate loss from forward outputs.
+
+        Args:
+            **kwargs: Additional arguments including adapter_name.
+
+        Returns:
+            Loss value as numpy array.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+        loss_instance: Loss = optimizer_config.loss_instance
+
+        inputs = optimizer_config.inputs
+        outputs = optimizer_config.outputs
+
+        assert inputs is not None and outputs is not None, \
+            'Cannot calculate loss of empty inputs and outputs'
+
+        # Keep the tensor for backward(); return a detached numpy copy.
+        loss_value = loss_instance(inputs, outputs, **kwargs)
+        optimizer_config.loss_value = loss_value
+        return loss_value.detach().cpu().float().numpy()
+
+    @remote_function()
+    def backward(self, **kwargs):
+        """Backward pass.
+
+        Args:
+            **kwargs: Additional arguments; an explicit (non-None)
+                gradient_accumulation_steps overrides the adapter's setting.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+        loss_value = optimizer_config.loss_value
+
+        assert loss_value is not None, 'Do forwarding and calculating loss before backward'
+
+        _gas = optimizer_config.gradient_accumulation_steps
+        # BUGFIX: sibling methods pass gradient_accumulation_steps through as
+        # possibly-None (kwargs.get). The old key-presence check accepted an
+        # explicit None and then divided by it; fall back instead.
+        if kwargs.get('gradient_accumulation_steps') is not None:
+            _gas = kwargs['gradient_accumulation_steps']
+
+        # Scale so gradients accumulated over _gas micro-steps average out.
+        loss_value = loss_value / _gas
+        loss_value.backward()
+        optimizer_config.cur_step += 1
+
+    @remote_function(collect='avg')
+    def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+        """Combined forward and backward pass.
+
+        Args:
+            inputs: Model inputs.
+            **kwargs: Additional arguments.
+
+        Returns:
+            Loss value.
+        """
+        self.forward(inputs=inputs, **kwargs)
+        loss = self.calculate_loss(**kwargs)
+        self.backward(**kwargs)
+        return loss
+
+    @remote_function(dispatch='all')
+    def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs):
+        """Clip gradient norm.
+
+        Args:
+            max_grad_norm: Maximum gradient norm.
+            norm_type: Type of norm to use.
+            **kwargs: Additional arguments.
+
+        Returns:
+            Total norm of gradients.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        parameters = self._get_trainable_parameters(adapter_name).values()
+
+        return torch.nn.utils.clip_grad_norm_(
+            parameters, max_grad_norm, norm_type=norm_type
+        ).detach().cpu().numpy()
+
+    @remote_function(dispatch='all')
+    def step(self, **kwargs):
+        """Optimizer step.
+
+        For PEFT models, gradients are NOT synchronized across DP ranks
+        because each DP replica trains independently with different data.
+        This is a common pattern for PEFT training where gradient averaging
+        is not strictly necessary.
+
+        Note: Uses dispatch='all' to ensure all workers execute this method,
+        though gradient sync is disabled for PEFT models.
+
+        Args:
+            **kwargs: Additional arguments forwarded to optimizer.step().
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        # BUGFIX: pop (not get) gradient_accumulation_steps — it is a control
+        # argument for do_grad_sync, and leaving it in kwargs made
+        # optimizer.step(**kwargs) raise TypeError on torch optimizers.
+        if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+            return
+
+        # Note: For PEFT/LoRA models, we skip gradient synchronization across DP ranks.
+        # Each DP replica trains independently. This avoids distributed communication
+        # complexity and is acceptable for most PEFT training scenarios.
+        # If gradient averaging is needed, use DDP-wrapped models instead.
+
+        optimizer = optimizer_config.optimizer
+        assert optimizer is not None, 'Set optimizer correctly before stepping'
+
+        optimizer.step(**kwargs)
+
+    def _is_model_ddp_wrapped(self) -> bool:
+        """Check if model is wrapped with DDP.
+
+        Returns:
+            True if model is wrapped with DDP (either Megatron DDP or PyTorch DDP).
+        """
+        from torch.nn.parallel import DistributedDataParallel as TorchDDP
+        return isinstance(self.model, (MegatronDDP, TorchDDP))
+
+    def _get_unwrapped_model(self) -> nn.Module:
+        """Get the unwrapped model.
+
+        Returns:
+            The base model without DDP wrapper.
+        """
+        return self.strategy.unwrap_model(self.model)
+
+    @remote_function(dispatch='all')
+    def zero_grad(self, **kwargs):
+        """Zero gradients.
+
+        Args:
+            **kwargs: Additional arguments forwarded to optimizer.zero_grad().
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        # BUGFIX: pop (not get) gradient_accumulation_steps so it is not
+        # forwarded to optimizer.zero_grad(**kwargs), which would raise
+        # TypeError on torch optimizers.
+        if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+            return
+
+        optimizer = optimizer_config.optimizer
+        if optimizer is not None:
+            optimizer.zero_grad(**kwargs)
+
+    @remote_function()
+    def lr_step(self, **kwargs):
+        """Learning rate scheduler step.
+
+        Args:
+            **kwargs: Additional arguments forwarded to lr_scheduler.step().
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        # Same leftover-kwarg fix as step()/zero_grad().
+        if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+            return
+
+        lr_scheduler = optimizer_config.lr_scheduler
+        if lr_scheduler is not None:
+            lr_scheduler.step(**kwargs)
+
+    @remote_function()
+    def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs):
+        """Set loss function.
+
+        Args:
+            loss_cls: Loss class or string name (looked up in twinkle.loss,
+                then loaded as a plugin).
+            **kwargs: Additional arguments.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        if isinstance(loss_cls, str):
+            if hasattr(twinkle.loss, loss_cls):
+                loss_cls = getattr(twinkle.loss, loss_cls)
+            else:
+                loss_cls = Plugin.load_plugin(loss_cls, Loss)
+        optimizer_config.loss_instance = loss_cls()
+
+    @remote_function()
+    def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs):
+        """Set optimizer.
+
+        Args:
+            optimizer_cls: Optimizer class or string name.
+            **kwargs: Additional arguments.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        # Resolve string names against torch.optim first, then plugins.
+        if isinstance(optimizer_cls, str):
+            if hasattr(torch.optim, optimizer_cls):
+                optimizer_cls = getattr(torch.optim, optimizer_cls)
+            else:
+                optimizer_cls = Plugin.load_plugin(optimizer_cls, Optimizer)
+
+        optimizer_config.optimizer = optimizer_cls(
+            self._get_trainable_parameters(adapter_name).values(), **kwargs
+        )
+
+    def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) -> Dict[str, nn.Parameter]:
+        """Get trainable parameters.
+
+        Args:
+            adapter_name: Name of adapter.
+
+        Returns:
+            Dict mapping parameter names to parameters.
+        """
+        is_default = adapter_name == _default_adapter_name
+        # Matches PEFT-style parameter names such as
+        # '...lora_A.<adapter_name>....'; for the default (empty) adapter
+        # every trainable parameter is returned via the is_default shortcut.
+        pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
+
+        params = {}
+        model = self.strategy.unwrap_model(self.model)
+        for name, param in model.named_parameters():
+            if param.requires_grad and (pattern.search(name) or is_default):
+                params[name] = param
+        return params
+
+    @remote_function()
+    def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwargs):
+        """Set learning rate scheduler.
+
+        Args:
+            scheduler_cls: Scheduler class or string name.
+            **kwargs: Additional arguments.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        if isinstance(scheduler_cls, str):
+            if hasattr(torch.optim.lr_scheduler, scheduler_cls):
+                scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_cls)
+            else:
+                scheduler_cls = Plugin.load_plugin(scheduler_cls, LRScheduler)
+
+        optimizer = optimizer_config.optimizer
+        assert optimizer is not None, 'Set optimizer before setting lr_scheduler'
+        optimizer_config.lr_scheduler = scheduler_cls(optimizer, **kwargs)
+
+    @remote_function()
+    def save(self, output_dir: str, **kwargs):
+        """Save model checkpoint.
+
+        Args:
+            output_dir: Output directory.
+            **kwargs: Additional arguments.
+        """
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        save_format = kwargs.pop('save_format', 'hf')  # 'hf' or 'megatron'
+
+        if save_format == 'hf':
+            self._save_hf_format(output_dir, adapter_name)
+        else:
+            self._save_megatron_format(output_dir, adapter_name)
+
+        self._save_tokenizer(output_dir, adapter_name)
+
+    def _save_hf_format(self, output_dir: str, adapter_name: str):
+        """Save in HuggingFace format using swift's GPTBridge."""
+        from twinkle.megatron.model.bridge import TwinkleBridgeAdapter
+        import os
+
+        # Only save from last PP rank
+        if not self.strategy.is_pipeline_last_stage():
+            return
+
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Use TwinkleBridgeAdapter which wraps swift's GPTBridge
+        adapter = TwinkleBridgeAdapter(
+            hf_config=self.hf_config,
+            tp_size=self.strategy.tp_size,
+            pp_size=self.strategy.pp_size,
+            ep_size=self.strategy.ep_size,
+            # BUGFIX: __init__ stores the pretrained path in `self.model_id`;
+            # `self.pretrained_model_name_or_path` raised AttributeError here.
+            model_path=self.model_id,
+        )
+
+        # Use swift's bridge to save weights
+        adapter.save_weights([self.model], output_dir, is_peft_format=False)
+
+        # Save config
+        self.hf_config.save_pretrained(output_dir)
+
+    def _save_megatron_format(self, output_dir: str, adapter_name: str):
+        """Save in Megatron checkpoint format (trainable params only,
+        one file per rank)."""
+        import os
+        os.makedirs(output_dir, exist_ok=True)
+
+        # (Removed an unused `unwrap_model` local — the state dict comes
+        # from _get_trainable_parameters, which unwraps internally.)
+        state_dict = self._get_trainable_parameters(adapter_name)
+
+        # Convert to CPU
+        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
+
+        # Save with rank info for distributed checkpointing
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt')
+        torch.save(cpu_state_dict, checkpoint_path)
+
+    def _save_tokenizer(self, output_dir: str, adapter_name: str = _default_adapter_name):
+        """Save tokenizer."""
+        optimizer_config = self.optimizer_group.get(adapter_name)
+        if optimizer_config and optimizer_config.template:
+            
optimizer_config.template.tokenizer.save_pretrained(output_dir) + + @remote_function(execute='first') + def get_state_dict(self, **kwargs): + """Get trainable state dict. + + Args: + **kwargs: Additional arguments. + + Returns: + State dict of trainable parameters. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + return self._get_trainable_parameters(adapter_name) + + _peft_patched = False + + @classmethod + def _patch_peft_for_megatron(cls): + """Patch PEFT's BaseTuner to handle Megatron's TransformerConfig. + + Megatron's TransformerConfig doesn't have a .get() method like HuggingFace + configs. This patch handles the AttributeError that occurs when PEFT tries + to check tie_word_embeddings. + + Reference: swift/swift/megatron/init.py::_patch_peft_BaseTuner + """ + if cls._peft_patched: + return + + from typing import List + import torch.nn as nn + from peft.tuners.tuners_utils import BaseTuner + + _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules + + def _get_tied_target_modules(self, model: nn.Module) -> List[str]: + try: + return _origin_get_tied_target_modules(self, model) + except AttributeError: + # Megatron's TransformerConfig doesn't have .get() method + # Check share_embeddings_and_output_weights instead + tied_target_modules = [] + if getattr(model, 'share_embeddings_and_output_weights', False): + for target_module in self.targeted_module_names: + module_name = target_module.split('.')[-1] + if module_name in ['output_layer', 'embedding', 'word_embeddings']: + tied_target_modules.append(target_module) + return tied_target_modules + + BaseTuner._get_tied_target_modules = _get_tied_target_modules + cls._peft_patched = True + + @remote_function() + def add_adapter_to_model( + self, + adapter_name: str, + config_or_dir: Union[Any, str], + **kwargs, + ): + """Add LoRA adapter to model. + + Args: + adapter_name: Name of the adapter. + config_or_dir: LoRA config or path to saved adapter. 
+ **kwargs: Additional arguments. + """ + from twinkle.megatron.utils import ( + prepare_lora_model, patch_deepcopy, get_target_modules, set_linear_is_expert + ) + + # Patch PEFT BaseTuner to handle Megatron's TransformerConfig + # which doesn't have a .get() method like HuggingFace configs + self._patch_peft_for_megatron() + + assert adapter_name, 'Use a non-empty adapter_name' + + model = self.strategy.unwrap_model(self.model) + + # Mark expert layers for MoE models + set_linear_is_expert(model) + + if isinstance(config_or_dir, str): + # Load from path + config_or_dir = HubOperation.download_model(config_or_dir) + from peft import PeftModel + model = PeftModel.from_pretrained( + model, config_or_dir, adapter_name=adapter_name, + is_trainable=kwargs.get('is_trainable', True) + ) + else: + # Create from config + from peft import LoraConfig, get_peft_model + + if not isinstance(config_or_dir, LoraConfig): + # Convert dict to LoraConfig + config_or_dir = LoraConfig(**config_or_dir) + + # Expand target_modules (e.g., 'all-linear' -> actual module names) + if config_or_dir.target_modules: + if isinstance(config_or_dir.target_modules, str): + target_modules = [config_or_dir.target_modules] + else: + target_modules = list(config_or_dir.target_modules) + + expanded_modules = get_target_modules(model, target_modules) + config_or_dir.target_modules = expanded_modules + + with patch_deepcopy(): + model = get_peft_model(model, config_or_dir, adapter_name=adapter_name) + + # Update model reference + if self._model_wrapped: + if isinstance(self.model, MegatronDDP): + self.model.module = model + else: + self.model = model + + # Create optimizer group for adapter + self.optimizer_group[adapter_name] = MegatronOptimizerGroup() + self.optimizer_group[adapter_name].adapter_name = adapter_name + self.optimizer_group[adapter_name].adapter_config = config_or_dir + self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get( + 'gradient_accumulation_steps', 1 + ) + + # 
Copy settings from default + default_config = self.optimizer_group.get(_default_adapter_name) + if default_config: + if default_config.template: + self.optimizer_group[adapter_name].template = default_config.template + if default_config.processor: + self.optimizer_group[adapter_name].processor = default_config.processor + + @remote_function() + def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs): + """Set template for input encoding. + + Args: + template_cls: Template class or string name. + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if isinstance(template_cls, str): + if hasattr(template, template_cls): + template_cls = getattr(template, template_cls) + else: + template_cls = Plugin.load_plugin(template_cls, template.Template) + optimizer_config.template = template_cls(self.model_id, **kwargs) + + @remote_function() + def set_processor(self, processor_cls: Union[Type[InputProcessor], str], **kwargs): + """Set input processor. + + Args: + processor_cls: Processor class or string name. + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if isinstance(processor_cls, str): + if hasattr(twinkle.processor, processor_cls): + processor_cls = getattr(twinkle.processor, processor_cls) + else: + processor_cls = Plugin.load_plugin(processor_cls, InputProcessor) + optimizer_config.processor = processor_cls(device_mesh=self.device_mesh, **kwargs) + + @remote_function(execute='first') + def get_train_configs(self, **kwargs): + """Get training configuration summary. + + Args: + **kwargs: Additional arguments. + + Returns: + Configuration summary string. 
+ """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + expr = f'Backend: Megatron-Core\n' + expr += f'TP size: {self.strategy.tp_size}\n' + expr += f'PP size: {self.strategy.pp_size}\n' + expr += f'CP size: {self.strategy.cp_size}\n' + expr += f'EP size: {self.strategy.ep_size}\n' + expr += f'Sequence Parallel: {self.strategy.sequence_parallel}\n' + + if optimizer_config.adapter_config is not None: + config = optimizer_config.adapter_config.__dict__ + config = {key: str(value) for key, value in config.items() if value is not None} + expr += f'Adapter config:\n{json.dumps(config, indent=2, ensure_ascii=False)}\n' + + if optimizer_config.optimizer: + expr += f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n' + expr += f'Learning rate: {optimizer_config.optimizer.defaults.get("lr", "N/A")}\n' + if optimizer_config.lr_scheduler: + expr += f'LR scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n' + expr += f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n' + + return expr + + def __repr__(self): + return ( + f"MegatronModel(model_id='{self.model_id}', " + f"tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, " + f"cp={self.strategy.cp_size}, ep={self.strategy.ep_size})" + ) + diff --git a/src/twinkle/model/strategy/__init__.py b/src/twinkle/model/strategy/__init__.py index cf9126a8..bb4d4ce6 100644 --- a/src/twinkle/model/strategy/__init__.py +++ b/src/twinkle/model/strategy/__init__.py @@ -1,2 +1,3 @@ from .base import TrainStrategy from .accelerate import AccelerateStrategy +from .megatron import MegatronStrategy diff --git a/src/twinkle/model/strategy/megatron.py b/src/twinkle/model/strategy/megatron.py new file mode 100644 index 00000000..487c77b9 --- /dev/null +++ b/src/twinkle/model/strategy/megatron.py @@ -0,0 +1,638 @@ +# Copyright (c) twinkle authors. All rights reserved. 
+"""Megatron training strategy for distributed model parallelism.""" +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.distributed as dist + +from .base import TrainStrategy + +try: + from twinkle import DeviceMesh +except ImportError: + DeviceMesh = None + +try: + import megatron.core + from megatron.core import parallel_state + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from packaging import version + MEGATRON_AVAILABLE = True + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +except ImportError: + MEGATRON_AVAILABLE = False + mcore_013 = False + + +def check_megatron_available(): + """Check if Megatron-Core is available.""" + if not MEGATRON_AVAILABLE: + raise ImportError( + "Megatron-Core is not installed. Please install it with: " + "pip install megatron-core" + ) + + +class MegatronStrategy(TrainStrategy): + """Strategy for Megatron-Core based distributed training. + + Supports Tensor Parallel (TP), Pipeline Parallel (PP), Context Parallel (CP), + Expert Parallel (EP), and Data Parallel (DP). + + This strategy integrates with twinkle's DeviceMesh to provide a unified + interface for distributed training configuration. + """ + + def __init__( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + expert_tensor_parallel_size: Optional[int] = None, + virtual_pipeline_model_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + use_distributed_optimizer: bool = True, + mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', + params_dtype: Optional[str] = None, + device_mesh: Optional['DeviceMesh'] = None, + megatron_args: Optional[Dict[str, Any]] = None, + ): + """Initialize MegatronStrategy. + + Args: + tensor_model_parallel_size: Degree of tensor model parallelism. 
+ pipeline_model_parallel_size: Degree of pipeline model parallelism. + context_parallel_size: Degree of context parallelism. + expert_model_parallel_size: Degree of expert model parallelism for MoE. + expert_tensor_parallel_size: Degree of expert tensor parallelism. + virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size. + sequence_parallel: Enable sequence parallelism. + use_distributed_optimizer: Use Megatron's distributed optimizer. + mixed_precision: Mixed precision mode. + params_dtype: Parameter dtype string (e.g., 'bf16', 'fp32'). + device_mesh: Twinkle DeviceMesh for distributed configuration. + megatron_args: Additional Megatron arguments. + """ + check_megatron_available() + + # If device_mesh is provided, extract parallel sizes from it + if device_mesh is not None: + tensor_model_parallel_size = self._get_dim_from_mesh(device_mesh, 'tp', tensor_model_parallel_size) + pipeline_model_parallel_size = self._get_dim_from_mesh(device_mesh, 'pp', pipeline_model_parallel_size) + context_parallel_size = self._get_dim_from_mesh(device_mesh, 'cp', context_parallel_size) + expert_model_parallel_size = self._get_dim_from_mesh(device_mesh, 'ep', expert_model_parallel_size) + + self.tp_size = tensor_model_parallel_size + self.pp_size = pipeline_model_parallel_size + self.cp_size = context_parallel_size + self.ep_size = expert_model_parallel_size + self.etp_size = expert_tensor_parallel_size or tensor_model_parallel_size + self.vp_size = virtual_pipeline_model_parallel_size + self.sequence_parallel = sequence_parallel + self.use_distributed_optimizer = use_distributed_optimizer + self.mixed_precision = mixed_precision + self.params_dtype = params_dtype + self.device_mesh = device_mesh + self.megatron_args = megatron_args or {} + + self._initialized = False + self._parallel_state = None + + @staticmethod + def _get_dim_from_mesh(device_mesh: 'DeviceMesh', dim_name: str, default: int) -> int: + """Get dimension size from device mesh. 
+ + Args: + device_mesh: The device mesh. + dim_name: Name of the dimension. + default: Default value if dimension not found. + + Returns: + Dimension size. + """ + if device_mesh is None: + return default + if hasattr(device_mesh, 'has_dim') and device_mesh.has_dim(dim_name): + return device_mesh.get_dim_size(dim_name) + return default + + @classmethod + def from_device_mesh( + cls, + device_mesh: 'DeviceMesh', + sequence_parallel: bool = False, + use_distributed_optimizer: bool = True, + mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', + **kwargs, + ) -> 'MegatronStrategy': + """Create MegatronStrategy from twinkle DeviceMesh. + + Args: + device_mesh: Twinkle DeviceMesh with dimension names like 'tp', 'pp', 'cp', 'ep', 'dp'. + sequence_parallel: Enable sequence parallelism. + use_distributed_optimizer: Use Megatron's distributed optimizer. + mixed_precision: Mixed precision mode. + **kwargs: Additional arguments. + + Returns: + MegatronStrategy instance. + """ + return cls( + device_mesh=device_mesh, + sequence_parallel=sequence_parallel, + use_distributed_optimizer=use_distributed_optimizer, + mixed_precision=mixed_precision, + **kwargs, + ) + + def initialize(self, **kwargs) -> None: + """Initialize Megatron parallel state. + + Should be called after distributed process group is initialized. + This sets up all the parallel groups for TP/PP/CP/EP/DP. 
+ """ + if self._initialized: + return + + if not dist.is_initialized(): + # Initialize torch distributed if not already done + dist.init_process_group(backend='nccl') + + world_size = dist.get_world_size() + + # Validate parallel configuration + total_model_parallel = self.tp_size * self.pp_size * self.cp_size + if world_size % total_model_parallel != 0: + raise ValueError( + f"World size ({world_size}) must be divisible by " + f"tp_size * pp_size * cp_size ({total_model_parallel})" + ) + + # Initialize Megatron parallel state + init_kwargs = { + 'tensor_model_parallel_size': self.tp_size, + 'pipeline_model_parallel_size': self.pp_size, + 'context_parallel_size': self.cp_size, + } + + if self.vp_size is not None: + init_kwargs['virtual_pipeline_model_parallel_size'] = self.vp_size + + # Handle MoE parallelism + if self.ep_size > 1: + init_kwargs['expert_model_parallel_size'] = self.ep_size + if mcore_013: + init_kwargs['expert_tensor_parallel_size'] = self.etp_size + + parallel_state.initialize_model_parallel(**init_kwargs) + + self._parallel_state = parallel_state + self._initialized = True + + # Set CUDA device + local_rank = dist.get_rank() % torch.cuda.device_count() + torch.cuda.set_device(local_rank) + + def destroy(self) -> None: + """Destroy parallel state and clean up resources.""" + if self._initialized and self._parallel_state is not None: + self._parallel_state.destroy_model_parallel() + self._initialized = False + + @property + def tp_rank(self) -> int: + """Get tensor parallel rank.""" + if not self._initialized: + return 0 + return self._parallel_state.get_tensor_model_parallel_rank() + + @property + def pp_rank(self) -> int: + """Get pipeline parallel rank.""" + if not self._initialized: + return 0 + return self._parallel_state.get_pipeline_model_parallel_rank() + + @property + def dp_rank(self) -> int: + """Get data parallel rank.""" + if not self._initialized: + return 0 + return self._parallel_state.get_data_parallel_rank() + + @property + def 
cp_rank(self) -> int: + """Get context parallel rank.""" + if not self._initialized: + return 0 + return self._parallel_state.get_context_parallel_rank() + + @property + def ep_rank(self) -> int: + """Get expert parallel rank.""" + if not self._initialized: + return 0 + return self._parallel_state.get_expert_model_parallel_rank() + + @property + def dp_size(self) -> int: + """Get data parallel size.""" + if not self._initialized: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + return world_size // (self.tp_size * self.pp_size * self.cp_size) + return self._parallel_state.get_data_parallel_world_size() + + @property + def tp_group(self): + """Get tensor parallel process group.""" + if not self._initialized: + return None + return self._parallel_state.get_tensor_model_parallel_group() + + @property + def dp_group(self): + """Get data parallel process group.""" + if not self._initialized: + return None + return self._parallel_state.get_data_parallel_group() + + @property + def pp_group(self): + """Get pipeline parallel process group.""" + if not self._initialized: + return None + return self._parallel_state.get_pipeline_model_parallel_group() + + @property + def cp_group(self): + """Get context parallel process group.""" + if not self._initialized: + return None + return self._parallel_state.get_context_parallel_group() + + @property + def ep_group(self): + """Get expert parallel process group.""" + if not self._initialized: + return None + return self._parallel_state.get_expert_model_parallel_group() + + def is_pipeline_first_stage(self) -> bool: + """Check if current rank is pipeline first stage.""" + if not self._initialized: + return True + return self._parallel_state.is_pipeline_first_stage() + + def is_pipeline_last_stage(self) -> bool: + """Check if current rank is pipeline last stage.""" + if not self._initialized: + return True + return self._parallel_state.is_pipeline_last_stage() + + def is_data_parallel_main_rank(self) -> bool: + 
"""Check if current rank is the main rank in data parallel group.""" + if not self._initialized: + return True + return self.dp_rank == 0 + + def get_params_dtype(self) -> torch.dtype: + """Get parameter dtype based on configuration. + + Returns: + PyTorch dtype for model parameters. + """ + if self.params_dtype is not None: + dtype_map = { + 'fp32': torch.float32, + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + } + return dtype_map.get(self.params_dtype, torch.bfloat16) + + if self.mixed_precision == 'bf16': + return torch.bfloat16 + elif self.mixed_precision == 'fp16': + return torch.float16 + return torch.float32 + + def wrap_model( + self, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + ) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]: + """Wrap model with distributed wrapper for data parallelism. + + In Megatron, TP/PP/CP/EP parallelism is already handled during model creation + (via TransformerConfig and parallel_state). This method only handles Data + Parallel (DP) wrapping, which synchronizes gradients across DP ranks. + + For PEFT/LoRA models: + - We skip DDP wrapping to avoid compatibility issues + - Gradients are synchronized manually via all_reduce_gradients() + - This is more flexible and works reliably with dynamically added LoRA modules + + For full model training (non-PEFT): + - Consider using Megatron's native training.setup_model_and_optimizer() + - Or use Megatron DDP with proper TransformerConfig + + Args: + model: The Megatron model to wrap (already parallelized via TP/PP). + optimizer: Optional optimizer (not wrapped here; use DistributedOptimizer separately if needed). + + Returns: + Tuple of (wrapped_model, wrapped_optimizer). + For PEFT models, wrapped_model is the original model (no DDP wrapper). 
+ """ + if not self._initialized: + self.initialize() + + # Check if this is a PEFT/LoRA model + is_peft_model = hasattr(model, 'peft_config') or hasattr(model, 'base_model') + + if is_peft_model: + # For PEFT models, skip DDP wrapping entirely. + # Reasons: + # 1. PEFT models have dynamically added modules that may cause issues with DDP + # 2. LoRA typically has very few trainable parameters, so manual gradient sync is efficient + # 3. Megatron DDP requires TransformerConfig which may not be accessible after PEFT wrapping + # 4. PyTorch DDP has device placement issues when model uses CPU initialization + # + # Instead, gradients should be synchronized manually using all_reduce_gradients() + # after backward() and before optimizer.step(). + return model, optimizer + + # For non-PEFT models, we can use Megatron DDP or PyTorch DDP + dp_group = self.dp_group + if dp_group is None or dist.get_world_size(dp_group) <= 1: + # No DP needed (single GPU or no DP group) + return model, optimizer + + # Get model config for Megatron DDP + config = getattr(model, 'config', None) + + # Check if model is on GPU (required for DDP) + model_device = next(model.parameters()).device + if model_device.type == 'cpu': + # Model is on CPU, need to move to GPU first + # This happens when use_cpu_initialization=True + local_rank = dist.get_rank() % torch.cuda.device_count() + model = model.to(f'cuda:{local_rank}') + + if config is not None and hasattr(config, 'tensor_model_parallel_size'): + # Model has TransformerConfig, use Megatron DDP + try: + from megatron.core.distributed import DistributedDataParallelConfig + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + use_distributed_optimizer=self.use_distributed_optimizer, + check_for_nan_in_grad=False, + bucket_size=None, # No bucketing for simpler gradient sync + ) + wrapped_model = MegatronDDP( + config=config, + ddp_config=ddp_config, + module=model, + ) + return wrapped_model, 
optimizer + except (ImportError, TypeError) as e: + # Fallback to PyTorch DDP if Megatron DDP fails + pass + + # Fallback: PyTorch DDP for models without TransformerConfig + from torch.nn.parallel import DistributedDataParallel as TorchDDP + wrapped_model = TorchDDP( + model, + process_group=dp_group, + # Note: Don't use device_ids for multi-GPU models or when model spans devices + ) + + return wrapped_model, optimizer + + def unwrap_model(self, model: nn.Module) -> nn.Module: + """Unwrap the distributed model to get the base model. + + Args: + model: The wrapped model. + + Returns: + The unwrapped base model. + """ + if isinstance(model, MegatronDDP): + return model.module + + from torch.nn.parallel import DistributedDataParallel as TorchDDP + if isinstance(model, TorchDDP): + return model.module + + return model + + def get_model_config( + self, + hidden_size: int, + num_attention_heads: int, + num_layers: int, + ffn_hidden_size: Optional[int] = None, + num_query_groups: Optional[int] = None, + vocab_size: int = 32000, + max_position_embeddings: int = 4096, + num_experts: Optional[int] = None, + moe_router_topk: int = 2, + **kwargs, + ): + """Create a Megatron TransformerConfig. + + Args: + hidden_size: Hidden dimension size. + num_attention_heads: Number of attention heads. + num_layers: Number of transformer layers. + ffn_hidden_size: FFN hidden size (default: 4 * hidden_size). + num_query_groups: Number of KV heads for GQA. + vocab_size: Vocabulary size. + max_position_embeddings: Maximum sequence length. + num_experts: Number of MoE experts. + moe_router_topk: Top-k for MoE routing. + **kwargs: Additional config arguments. + + Returns: + Megatron TransformerConfig. 
+ """ + from megatron.core.transformer import TransformerConfig + + config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups or num_attention_heads, + ffn_hidden_size=ffn_hidden_size or 4 * hidden_size, + use_cpu_initialization=True, + params_dtype=self.get_params_dtype(), + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + context_parallel_size=self.cp_size, + expert_model_parallel_size=self.ep_size, + sequence_parallel=self.sequence_parallel, + num_moe_experts=num_experts, + moe_router_topk=moe_router_topk, + **kwargs, + ) + + return config + + def sync_gradients(self, model: Optional[nn.Module] = None) -> None: + """Synchronize gradients across data parallel group. + + For DDP-wrapped models, gradients are synchronized automatically. + For non-DDP models (e.g., PEFT models), this performs manual all-reduce. + + Args: + model: Optional model to sync gradients for. If None, only barrier. + """ + if not self._initialized: + return + + dp_group = self.dp_group + if dp_group is None: + return + + dp_size = dist.get_world_size(dp_group) + if dp_size <= 1: + return + + if model is not None: + # Manual gradient synchronization for non-DDP models (e.g., PEFT) + self.all_reduce_gradients(model) + else: + # Just barrier for DDP models + dist.barrier(dp_group) + + def all_reduce_gradients(self, model: nn.Module) -> None: + """All-reduce gradients of trainable parameters across data parallel group. + + This is used for PEFT/LoRA models that are not wrapped with DDP. + Gradients are averaged across all DP ranks. + + Args: + model: The model whose gradients to synchronize. 
+ """ + if not self._initialized: + return + + dp_group = self.dp_group + if dp_group is None: + return + + dp_size = dist.get_world_size(dp_group) + if dp_size <= 1: + return + + # Collect gradients from trainable parameters + grads = [] + for param in model.parameters(): + if param.requires_grad and param.grad is not None: + grads.append(param.grad.data) + + if not grads: + return + + # Flatten all gradients into a single tensor for efficient communication + # This reduces the number of all-reduce operations + flat_grads = torch.cat([g.contiguous().view(-1) for g in grads]) + + # All-reduce and average + dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM, group=dp_group) + flat_grads.div_(dp_size) + + # Unflatten back to original gradient tensors + offset = 0 + for grad in grads: + numel = grad.numel() + grad.copy_(flat_grads[offset:offset + numel].view_as(grad)) + offset += numel + + def all_reduce( + self, + tensor: torch.Tensor, + op: dist.ReduceOp = dist.ReduceOp.SUM, + group: Optional[dist.ProcessGroup] = None, + ) -> torch.Tensor: + """All-reduce tensor across specified group. + + Args: + tensor: Input tensor. + op: Reduce operation. + group: Process group (defaults to data parallel group). + + Returns: + Reduced tensor. + """ + if not self._initialized: + return tensor + + if group is None: + group = self.dp_group + + if group is not None: + dist.all_reduce(tensor, op=op, group=group) + + return tensor + + def broadcast( + self, + tensor: torch.Tensor, + src: int = 0, + group: Optional[dist.ProcessGroup] = None, + ) -> torch.Tensor: + """Broadcast tensor from source rank. + + Args: + tensor: Input tensor. + src: Source rank. + group: Process group (defaults to data parallel group). + + Returns: + Broadcasted tensor. 
+ """ + if not self._initialized: + return tensor + + if group is None: + group = self.dp_group + + if group is not None: + dist.broadcast(tensor, src=src, group=group) + + return tensor + + def get_parallel_info(self) -> Dict[str, Any]: + """Get parallelism configuration information. + + Returns: + Dict with parallel configuration details. + """ + return { + 'tp_size': self.tp_size, + 'pp_size': self.pp_size, + 'cp_size': self.cp_size, + 'ep_size': self.ep_size, + 'etp_size': self.etp_size, + 'vp_size': self.vp_size, + 'dp_size': self.dp_size, + 'sequence_parallel': self.sequence_parallel, + 'use_distributed_optimizer': self.use_distributed_optimizer, + 'mixed_precision': self.mixed_precision, + 'tp_rank': self.tp_rank, + 'pp_rank': self.pp_rank, + 'dp_rank': self.dp_rank, + 'cp_rank': self.cp_rank, + 'ep_rank': self.ep_rank, + } + + def __repr__(self) -> str: + return ( + f"MegatronStrategy(tp={self.tp_size}, pp={self.pp_size}, " + f"cp={self.cp_size}, ep={self.ep_size}, dp={self.dp_size}, " + f"sequence_parallel={self.sequence_parallel})" + ) diff --git a/src/twinkle/model/transformers.py b/src/twinkle/model/transformers.py index 3b399b15..2423529d 100644 --- a/src/twinkle/model/transformers.py +++ b/src/twinkle/model/transformers.py @@ -26,6 +26,7 @@ from twinkle.utils.plugin import Plugin from .base import TwinkleModel from .strategy import AccelerateStrategy +from twinkle.utils import torch_util @dataclass @@ -231,7 +232,10 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): scaler.unscale_(optimizer) parameters = self._get_trainable_parameters(adapter_name).values() - return torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm, norm_type=norm_type).detach().cpu().numpy() + grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm, norm_type=norm_type) + # Convert DTensor to local tensor for FSDP2 compatibility + grad_norm = torch_util.to_local_tensor(grad_norm) + return grad_norm.detach().cpu().numpy() 
@remote_function() def step(self, **kwargs): diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 3bf159f3..157c62a8 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -60,8 +60,12 @@ def _inner_collate_fn(self, batch): values = [item[key] for item in batch] if isinstance(values[0], np.ndarray): - values = [torch.from_numpy(v) for v in values] - result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side) + # Skip string arrays - they can't be converted to tensors + if values[0].dtype.kind in ('U', 'S', 'O'): # Unicode, byte string, or object + result[key] = values + else: + values = [torch.from_numpy(v) for v in values] + result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side) elif isinstance(values[0], torch.Tensor): result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side) else: diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index c2a017ce..6f52529e 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -166,3 +166,22 @@ def seed_everything(seed: Optional[int] = 42, deterministic: bool = False): if Torch.is_npu_available(): os.environ["ASCEND_LAUNCH_BLOCKING"] = "1" os.environ["HCCL_DETERMINISTIC"] = "1" + + @staticmethod + def to_local_tensor(tensor: 'torch.Tensor') -> 'torch.Tensor': + """Convert DTensor to local tensor if needed. + + Args: + tensor: A torch.Tensor or DTensor instance. + + Returns: + A local torch.Tensor. 
+ """ + import torch + if hasattr(tensor, 'full_tensor'): + # DTensor from torch.distributed.tensor + return tensor.full_tensor() + elif hasattr(tensor, 'to_local'): + # Alternative DTensor API + return tensor.to_local() + return tensor diff --git a/twinkle b/twinkle deleted file mode 120000 index 5a08ecd1..00000000 --- a/twinkle +++ /dev/null @@ -1 +0,0 @@ -src/twinkle \ No newline at end of file From 05172330677df33be64fc73ab59f282e36ab1e2e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 Jan 2026 19:26:41 +0800 Subject: [PATCH 02/22] local done --- cookbook/megatron/lora.py | 299 ++++---- .../loss/vocab_parallel_cross_entropy.py | 31 +- src/twinkle/megatron/__init__.py | 1 - src/twinkle/megatron/model/__init__.py | 1 - src/twinkle/megatron/model/bridge.py | 19 +- src/twinkle/megatron/model/initializer.py | 30 +- src/twinkle/megatron/model/qwen3.py | 3 +- src/twinkle/megatron/model/swift_bridge.py | 673 ++++++++++++++++++ src/twinkle/megatron/tuners/lora.py | 2 - src/twinkle/megatron/utils.py | 64 +- src/twinkle/model/megatron.py | 318 ++++++++- src/twinkle/model/strategy/megatron.py | 216 ++++-- 12 files changed, 1372 insertions(+), 285 deletions(-) create mode 100644 src/twinkle/megatron/model/swift_bridge.py diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index 589b1448..907828e5 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -1,27 +1,40 @@ # Copyright (c) twinkle authors. All rights reserved. -"""Megatron-Core LoRA training example. +"""Megatron-Core LoRA training example with full 4D parallelism. This example demonstrates LoRA fine-tuning using Megatron-Core backend. -Supports both local (DDP) and Ray distributed modes. +Supports Tensor Parallel (TP), Pipeline Parallel (PP), Context Parallel (CP), +and Data Parallel (DP). DP is automatically calculated from WORLD_SIZE. 
-Usage (local mode with 4 GPUs): - torchrun --nproc_per_node=4 cookbook/megatron/lora.py --mode local +The script uses Megatron's get_forward_backward_func() for unified pipeline +scheduling, ensuring proper multi-tenant isolation through process groups. -Usage (Ray mode): - python cookbook/megatron/lora.py --mode ray +TODO: Add Expert Parallel (EP) support for MoE models. + +Usage (8 GPUs with CP2 PP2 TP2, DP auto-calculated as 1): + torchrun --nproc_per_node=8 cookbook/megatron/lora.py \ + --tp_size 2 --pp_size 2 --cp_size 2 + +Usage (4 GPUs with TP2, DP auto-calculated as 2): + torchrun --nproc_per_node=4 cookbook/megatron/lora.py --tp_size 2 + +Usage (single GPU for debugging): + torchrun --nproc_per_node=1 cookbook/megatron/lora.py + +Note: WORLD_SIZE is automatically detected from torchrun, no need to specify it twice. """ import argparse +import os import numpy as np from peft import LoraConfig from torch.optim import AdamW -from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR import twinkle from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta -from twinkle.loss import CrossEntropyLoss, MegatronCrossEntropyLoss +from twinkle.loss import MegatronCrossEntropyLoss from twinkle.model import MegatronModel from twinkle.processor import InputProcessor @@ -29,73 +42,79 @@ def parse_args(): - parser = argparse.ArgumentParser(description='Megatron LoRA Training') + parser = argparse.ArgumentParser(description='Megatron LoRA Training with 4D Parallelism') # Mode selection - parser.add_argument('--mode', type=str, default='ray', + parser.add_argument('--mode', type=str, default='local', choices=['local', 'ray'], - help='Distributed mode: local (DDP) or ray') - - # Model arguments - parser.add_argument('--model_name', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct', - help='HuggingFace model name 
or path') - parser.add_argument('--output_dir', type=str, default='./output/megatron_lora', - help='Output directory for checkpoints') + help='Distributed mode: local (torchrun) or ray') - # Parallelism arguments + # Number of GPUs parser.add_argument('--nproc_per_node', type=int, default=4, - help='Number of processes per node') - parser.add_argument('--tp_size', type=int, default=2, - help='Tensor parallel size') - parser.add_argument('--dp_size', type=int, default=2, - help='Data parallel size') - parser.add_argument('--sequence_parallel', action='store_true', - help='Enable sequence parallelism') - parser.add_argument('--mixed_precision', type=str, default='bf16', - choices=['no', 'fp16', 'bf16'], - help='Mixed precision mode') - - # LoRA arguments - parser.add_argument('--lora_rank', type=int, default=8, - help='LoRA rank') - parser.add_argument('--lora_alpha', type=int, default=32, - help='LoRA alpha') - parser.add_argument('--lora_dropout', type=float, default=0.05, - help='LoRA dropout') - parser.add_argument('--target_modules', type=str, default='all-linear', - help='Target modules for LoRA') - - # Training arguments - parser.add_argument('--batch_size', type=int, default=4, - help='Batch size per GPU') - parser.add_argument('--gradient_accumulation_steps', type=int, default=16, - help='Gradient accumulation steps') - parser.add_argument('--learning_rate', type=float, default=1e-4, - help='Learning rate') - parser.add_argument('--max_grad_norm', type=float, default=1.0, - help='Maximum gradient norm for clipping') - parser.add_argument('--max_steps', type=int, default=1000, - help='Maximum training steps') - parser.add_argument('--save_steps', type=int, default=50, - help='Checkpoint save interval') - - # Dataset arguments - parser.add_argument('--dataset', type=str, default='ms://modelscope/competition_math', - help='Dataset name') - - return parser.parse_args() + help='Total number of GPUs') + + # 4D Parallelism configuration + # Total GPUs = DP * CP * 
PP * TP (DP is auto-calculated) + # TODO: Add EP (Expert Parallel) for MoE models + parser.add_argument('--tp_size', type=int, default=1, + help='Tensor Parallel size (splits model layers horizontally)') + parser.add_argument('--pp_size', type=int, default=1, + help='Pipeline Parallel size (splits model layers vertically)') + parser.add_argument('--cp_size', type=int, default=1, + help='Context Parallel size (splits sequence across GPUs)') + # Note: DP size is automatically calculated as: WORLD_SIZE / (TP * PP * CP) + + # Sequence parallel (usually enabled with TP > 1) + parser.add_argument('--sequence_parallel', action='store_true', default=False, + help='Enable sequence parallelism (recommended when TP > 1)') + + # Max steps for quick testing + parser.add_argument('--max_steps', type=int, default=None, + help='Maximum training steps (for testing)') + + args = parser.parse_args() + + # Auto-detect world size from environment (set by torchrun) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + args.world_size = world_size + + # Auto-calculate DP size from total GPUs and model parallel sizes + model_parallel_size = args.tp_size * args.pp_size * args.cp_size + if world_size % model_parallel_size != 0: + raise ValueError( + f'WORLD_SIZE ({world_size}) must be divisible by ' + f'TP({args.tp_size}) * PP({args.pp_size}) * CP({args.cp_size}) = {model_parallel_size}' + ) + args.dp_size = world_size // model_parallel_size + + logger.info(f'4D Parallelism config: DP={args.dp_size} (auto), TP={args.tp_size}, ' + f'PP={args.pp_size}, CP={args.cp_size}, Total GPUs={world_size}') + + return args def create_device_mesh(args) -> DeviceMesh: - """Create device mesh for Megatron parallelism.""" - # For Megatron: mesh shape is (dp, tp) - # dp_size * tp_size = nproc_per_node - mesh = np.arange(args.nproc_per_node).reshape(args.dp_size, args.tp_size) + """Create device mesh for Megatron 4D parallelism. 
+ + Megatron uses the following parallelism hierarchy (outer to inner): + - Data Parallel (DP): Replicates model, splits data (auto-calculated) + - Context Parallel (CP): Splits sequence across GPUs + - Pipeline Parallel (PP): Splits layers across stages + - Tensor Parallel (TP): Splits layers horizontally + + TODO: Add Expert Parallel (EP) dimension for MoE models. + + Mesh shape: (dp, cp, pp, tp) + """ + total_gpus = args.world_size + + # Create mesh with shape (dp, cp, pp, tp) + mesh = np.arange(total_gpus).reshape(args.dp_size, args.cp_size, args.pp_size, args.tp_size) device_mesh = DeviceMesh( device_type='cuda', mesh=mesh, - mesh_dim_names=('dp', 'tp'), + mesh_dim_names=('dp', 'cp', 'pp', 'tp'), ) return device_mesh @@ -105,97 +124,102 @@ def create_device_group(args): device_group = [ DeviceGroup( name='model', - ranks=list(range(args.nproc_per_node)), + ranks=list(range(args.world_size)), device_type=Platform.get_platform().device_prefix(), ) ] return device_group -def create_dataset(args): +def create_dataset(): """Create and preprocess dataset.""" - dataset = Dataset(dataset_meta=DatasetMeta(args.dataset)) - dataset.set_template('Qwen3Template', model_id=args.model_name) + dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + dataset.set_template('Qwen3Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct') dataset.map('CompetitionMathProcessor') - dataset.encode(batched=True) + # IMPORTANT: Use load_from_cache_file=False to avoid stale cache with incorrect labels + dataset.encode(batched=True, load_from_cache_file=False) return dataset def train(args): - """Main training function.""" + """Main training function with 4D parallelism support.""" # Create dataloader - dataloader = DataLoader( - dataset=lambda: create_dataset(args), - batch_size=args.batch_size, - ) + dataloader = DataLoader(dataset=create_dataset, batch_size=8) - # Create Megatron model + # Create Megatron model with 4D parallelism + # TODO: Add 
expert_model_parallel_size for MoE models model = MegatronModel( - pretrained_model_name_or_path=args.model_name, + pretrained_model_name_or_path='ms://Qwen/Qwen2.5-7B-Instruct', tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + context_parallel_size=args.cp_size, sequence_parallel=args.sequence_parallel, - mixed_precision=args.mixed_precision, + mixed_precision='bf16', ) + # Set template, processor, loss on DEFAULT adapter FIRST + # These will be copied when adding LoRA adapter + model.set_template('Qwen3Template') + model.set_processor(InputProcessor, padding_side='right') + model.set_loss(MegatronCrossEntropyLoss) + # Configure LoRA adapter lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=args.target_modules, - ) - - model.add_adapter_to_model( - 'default', - lora_config, - gradient_accumulation_steps=args.gradient_accumulation_steps, + target_modules='all-linear' ) - # Set template and processor - model.set_template('Qwen3Template') - model.set_processor(InputProcessor, padding_side='right') + # Add LoRA adapter - template, processor, loss_instance will be copied from default + model.add_adapter_to_model('lora', lora_config, gradient_accumulation_steps=16) - # Set loss, optimizer, scheduler - model.set_loss(MegatronCrossEntropyLoss) - model.set_optimizer(AdamW, lr=args.learning_rate, weight_decay=0.01) - model.set_lr_scheduler(CosineAnnealingLR, T_max=args.max_steps) + # Set optimizer and scheduler for LoRA adapter (must be after add_adapter_to_model) + model.set_optimizer(AdamW, lr=1e-4, adapter_name='lora') + model.set_lr_scheduler(LinearLR, adapter_name='lora') # Print training configuration logger.info(get_device_placement()) - logger.info(model.get_train_configs()) + logger.info(model.get_train_configs(adapter_name='lora')) # Training loop - global_step = 0 + gradient_accumulation_steps = 16 + optimizer_step = 0 + max_steps = args.max_steps + 
for step, batch in enumerate(dataloader): - if global_step >= args.max_steps: - break - - # Forward-backward pass - output = model.forward_backward(inputs=batch) + output = model.forward_backward(inputs=batch, adapter_name='lora') - # Log loss at gradient accumulation boundary - if step % args.gradient_accumulation_steps == 0: - logger.info(f'Step {global_step}, Loss: {output}') - global_step += 1 - - # Gradient clipping and optimizer step - model.clip_grad_norm(args.max_grad_norm) - model.step() - model.zero_grad() - model.lr_step() - - # Save checkpoint - if global_step > 0 and global_step % args.save_steps == 0: - model.save(f'{args.output_dir}/checkpoint-{global_step}') + # Only perform optimizer step at gradient accumulation boundary + if (step + 1) % gradient_accumulation_steps == 0: + optimizer_step = (step + 1) // gradient_accumulation_steps + + # Log loss + logger.info(f'Current is step {optimizer_step}, loss: {output}') + + # Gradient clipping and optimizer step + model.clip_grad_norm(1.0, adapter_name='lora') + model.step(adapter_name='lora') + model.zero_grad(adapter_name='lora') + model.lr_step(adapter_name='lora') + + # Save checkpoint every 100 optimizer steps + if optimizer_step % 100 == 0: + model.save('./output/megatron_lora', adapter_name='lora') + + # Check max_steps for early stopping (for testing) + if max_steps is not None and optimizer_step >= max_steps: + logger.info(f'Reached max_steps ({max_steps}), stopping training.') + break - # Save final model - model.save(args.output_dir) - logger.info(f'Model saved to {args.output_dir}') + # Save final checkpoint + logger.info(f'Training completed! 
Final step: {optimizer_step}') + model.save('./output/megatron_lora', adapter_name='lora') def main(): args = parse_args() + # Set TWINKLE_MODE environment variable for strategy to detect + os.environ['TWINKLE_MODE'] = args.mode + # Create device mesh and group device_mesh = create_device_mesh(args) device_group = create_device_group(args) @@ -203,14 +227,49 @@ def main(): # Initialize twinkle with specified mode twinkle.initialize( mode=args.mode, - nproc_per_node=args.nproc_per_node, + nproc_per_node=args.world_size, groups=device_group, global_device_mesh=device_mesh, lazy_collect=False, ) - # Start training - train(args) + try: + # Start training + train(args) + finally: + # Clean up distributed process groups + cleanup_distributed() + + +def cleanup_distributed(): + """Clean up all distributed process groups.""" + import torch + import torch.distributed as dist + + # Synchronize all processes before cleanup + if dist.is_initialized(): + try: + # Use barrier with timeout to prevent hanging + dist.barrier() + except Exception as e: + logger.warning(f'Barrier failed during cleanup: {e}') + + try: + dist.destroy_process_group() + except Exception as e: + logger.warning(f'Failed to destroy process group: {e}') + + # Also clean up Megatron's parallel state if initialized + try: + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + mpu.destroy_model_parallel() + except Exception: + pass + + # Clear CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() if __name__ == '__main__': diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py index 16dc03aa..5a47bbfb 100644 --- a/src/twinkle/loss/vocab_parallel_cross_entropy.py +++ b/src/twinkle/loss/vocab_parallel_cross_entropy.py @@ -27,6 +27,9 @@ class VocabParallelCrossEntropyLoss(Loss): This loss uses Megatron's vocab_parallel_cross_entropy which correctly handles the distributed computation. 
+ NOTE: Labels are expected to be pre-shifted by the template (using np.roll). + This loss does NOT perform additional shifting. + Fallback: When Megatron is not available or TP=1, uses standard CrossEntropyLoss. """ @@ -34,44 +37,42 @@ def __call__(self, inputs, outputs, **kwargs): logits = outputs['logits'] labels = inputs['labels'] - # Get dimensions - # logits: [batch, seq, vocab] or [batch, seq, partition_vocab] - # labels: [batch, seq] + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = labels[:, :-1].contiguous() if not MEGATRON_AVAILABLE: # Fallback to standard loss - logits_2d = logits.view(-1, logits.shape[-1]) - labels_1d = labels.view(-1) + logits_2d = shift_logits.view(-1, shift_logits.shape[-1]) + labels_1d = shift_labels.view(-1) return F.cross_entropy(logits_2d, labels_1d, ignore_index=-100) tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size == 1: # No TP, use standard cross entropy - logits_2d = logits.view(-1, logits.shape[-1]) - labels_1d = labels.view(-1) + logits_2d = shift_logits.view(-1, shift_logits.shape[-1]) + labels_1d = shift_labels.view(-1) return F.cross_entropy(logits_2d, labels_1d, ignore_index=-100) # Use Megatron's vocab-parallel cross entropy # Megatron expects [seq, batch, vocab] format for logits # and [seq, batch] for labels - # Transpose logits: [batch, seq, vocab] -> [seq, batch, vocab] - logits_sbv = logits.transpose(0, 1).contiguous() + # Transpose logits: [batch, seq-1, vocab] -> [seq-1, batch, vocab] + logits_sbv = shift_logits.transpose(0, 1).contiguous() - # Transpose labels: [batch, seq] -> [seq, batch] - # Must be contiguous for Megatron's view() operations - labels_sb = labels.transpose(0, 1).contiguous() + # Transpose labels: [batch, seq-1] -> [seq-1, batch] + labels_sb = shift_labels.transpose(0, 1).contiguous() # Megatron's vocab_parallel_cross_entropy handles the TP sharding correctly - # It returns per-token loss of shape [seq, batch] + # It returns per-token loss of shape [seq-1, 
batch] per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb) - # Transpose back: [seq, batch] -> [batch, seq] + # Transpose back: [seq-1, batch] -> [batch, seq-1] per_token_loss = per_token_loss.transpose(0, 1).contiguous() # Apply loss mask (ignore labels == -100) - loss_mask = (labels != -100).float() + loss_mask = (shift_labels != -100).float() # Compute mean loss (only over non-masked positions) loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) diff --git a/src/twinkle/megatron/__init__.py b/src/twinkle/megatron/__init__.py index b5ccdf62..b91ea610 100644 --- a/src/twinkle/megatron/__init__.py +++ b/src/twinkle/megatron/__init__.py @@ -2,7 +2,6 @@ """Megatron-Core integration for twinkle training framework. This module provides independent implementation for Megatron support, -without external dependencies on swift's GPTBridge. """ from .tuners import LoraParallelLinear, dispatch_megatron diff --git a/src/twinkle/megatron/model/__init__.py b/src/twinkle/megatron/model/__init__.py index d421c70b..a330f603 100644 --- a/src/twinkle/megatron/model/__init__.py +++ b/src/twinkle/megatron/model/__init__.py @@ -2,7 +2,6 @@ """Megatron model initialization and weight conversion. This module provides independent implementation for weight loading/saving, -without external dependencies on swift. """ from .bridge import ( diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index 33a56356..a75dce7b 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -1,10 +1,8 @@ # Copyright (c) twinkle authors. All rights reserved. # GPT Bridge for HuggingFace to Megatron-Core weight conversion. -# This implementation is adapted from ms-swift's GPTBridge. """Weight conversion bridge between HuggingFace and Megatron-Core formats. 
This module provides independent implementation for weight loading/saving, -adapted from swift's GPTBridge but without external dependencies. Supports: - Qwen2.5 / Qwen3 model families @@ -309,6 +307,18 @@ def from_hf_config( getattr(hf_config, 'moe_topk', 2) shared_expert_size = getattr(hf_config, 'shared_expert_intermediate_size', 0) + # Determine QKV bias setting + # Qwen2 has attention bias by default (hardcoded in transformers), + # but config doesn't have 'attention_bias' field + model_type = getattr(hf_config, 'model_type', 'qwen2') + if hasattr(hf_config, 'attention_bias'): + add_qkv_bias = hf_config.attention_bias + elif model_type in ('qwen2', 'qwen2_5'): + # Qwen2/Qwen2.5 uses bias=True for Q, K, V projections + add_qkv_bias = True + else: + add_qkv_bias = False + return cls( tp_size=tp_size, pp_size=pp_size, @@ -321,7 +331,7 @@ def from_hf_config( vocab_size=vocab_size, padded_vocab_size=padded_vocab_size, intermediate_size=getattr(hf_config, 'intermediate_size', 11008), - add_qkv_bias=getattr(hf_config, 'attention_bias', False), + add_qkv_bias=add_qkv_bias, add_bias_linear=getattr(hf_config, 'mlp_bias', False), qk_layernorm=getattr(hf_config, 'qk_layernorm', False) or \ getattr(hf_config, 'use_qk_norm', False), @@ -329,14 +339,13 @@ def from_hf_config( num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, shared_expert_intermediate_size=shared_expert_size, - model_type=getattr(hf_config, 'model_type', 'qwen2'), + model_type=model_type, ) class TwinkleGPTBridge: """Bridge for converting weights between HuggingFace and Megatron-Core formats. - Adapted from swift's GPTBridge implementation. Supports Qwen2.5 / Qwen3 model families. 
""" diff --git a/src/twinkle/megatron/model/initializer.py b/src/twinkle/megatron/model/initializer.py index ae0b5346..043486e2 100644 --- a/src/twinkle/megatron/model/initializer.py +++ b/src/twinkle/megatron/model/initializer.py @@ -250,30 +250,42 @@ def load_from_hf( ) -> None: """Load HuggingFace checkpoint into Megatron model. - Uses swift's GPTBridge for maximum compatibility and stability. + TODO: Remove swift dependency. Implement independent bridge logic. Args: model: The Megatron model. - hf_model_path: Path to HuggingFace checkpoint. + hf_model_path: Path to HuggingFace checkpoint or model ID. hf_config: HuggingFace model config. """ - from .bridge import TwinkleBridgeAdapter + import os + + # Resolve model path if it's a model ID (not a local path) + if not os.path.isdir(hf_model_path): + # Try to download using HubOperation (twinkle's hub abstraction) + try: + from twinkle.hub import HubOperation + hf_model_path = HubOperation.download_model(hf_model_path) + except ImportError: + # Fallback to modelscope/huggingface snapshot download + try: + from modelscope import snapshot_download + hf_model_path = snapshot_download(hf_model_path) + except ImportError: + from huggingface_hub import snapshot_download + hf_model_path = snapshot_download(hf_model_path) # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(hf_config.vocab_size) - # Create bridge adapter - adapter = TwinkleBridgeAdapter( + from .swift_bridge import create_bridge_adapter + adapter = create_bridge_adapter( hf_config=hf_config, tp_size=self.tp_size, pp_size=self.pp_size, ep_size=self.ep_size, - etp_size=self.etp_size, - model_path=hf_model_path, + model_dir=hf_model_path, padded_vocab_size=padded_vocab_size, ) - - # Load weights using swift's bridge adapter.load_weights(model, hf_model_path) diff --git a/src/twinkle/megatron/model/qwen3.py b/src/twinkle/megatron/model/qwen3.py index d49c4c92..b87b79ae 100644 --- a/src/twinkle/megatron/model/qwen3.py +++ 
b/src/twinkle/megatron/model/qwen3.py @@ -1,8 +1,7 @@ # Copyright (c) twinkle authors. All rights reserved. """Qwen3 model metadata for Megatron-Core. -This module provides metadata for Qwen3 models. The actual weight conversion -is handled by swift's GPTBridge, which already has full Qwen3 support. +This module provides metadata for Qwen3 models. """ from typing import Any, Dict diff --git a/src/twinkle/megatron/model/swift_bridge.py b/src/twinkle/megatron/model/swift_bridge.py new file mode 100644 index 00000000..1a1ad4b9 --- /dev/null +++ b/src/twinkle/megatron/model/swift_bridge.py @@ -0,0 +1,673 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Bridge module for Megatron-Core weight conversion. + +TODO: Remove dependency on swift package. The bridge logic should be +implemented independently in twinkle to avoid external dependencies. + +This module provides: +1. TwinkleArgs: A dataclass that mimics megatron.training.get_args() return value +2. MegatronBridgeInitializer: Creates Megatron models with proper initialization +""" +import os +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional + +import torch.distributed as dist + +try: + from safetensors.torch import safe_open + SAFETENSORS_AVAILABLE = True +except ImportError: + SAFETENSORS_AVAILABLE = False + + +# Cache for Swift bridge availability check +_SWIFT_BRIDGE_AVAILABLE = None +_SWIFT_GPT_BRIDGE_CLASS = None + + +def deep_getattr(obj, attr: str, default=None): + """Get nested attribute from object using dot notation.""" + try: + for key in attr.split('.'): + obj = getattr(obj, key) + return obj + except AttributeError: + return default + + +def is_last_rank() -> bool: + """Check if current process is the last rank.""" + if not dist.is_initialized(): + return True + return dist.get_rank() == dist.get_world_size() - 1 + + +class LazyTensor: + """Lazy tensor wrapper for deferred loading.""" + def 
__init__(self, tensor=None, loader=None): + self.tensor = tensor + self.loader = loader + + def load(self): + if self.tensor is None: + return self.loader() + return self.tensor + + +class SafetensorLazyLoader: + """Lazy loader for safetensor files.""" + def __init__(self, hf_model_dir: str, is_peft_format: bool = False): + self.hf_model_dir = hf_model_dir + self.is_peft_format = is_peft_format + self._weight_map = {} + self._file_handles = {} + self._load_index() + + def _open_file(self, filename: str): + if filename not in self._file_handles: + file_path = os.path.join(self.hf_model_dir, filename) + self._file_handles[filename] = safe_open(file_path, framework='pt') + return self._file_handles[filename] + + def _load_index(self): + import json + index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') + if os.path.exists(index_path): + with open(index_path, 'r') as f: + self._weight_map = json.load(f).get('weight_map', {}) + else: + safetensors_fname = 'adapter_model.safetensors' if self.is_peft_format else 'model.safetensors' + safetensors_file = os.path.join(self.hf_model_dir, safetensors_fname) + if os.path.exists(safetensors_file): + with safe_open(safetensors_file, framework='pt') as f: + for key in f.keys(): + self._weight_map[key] = safetensors_fname + + def get_state_dict(self): + return {k: LazyTensor(loader=partial(self._load_tensor, key=k)) for k in self._weight_map.keys()} + + def _load_tensor(self, key): + return self._open_file(self._weight_map[key]).get_tensor(key) + + def close(self): + self._file_handles.clear() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +@dataclass +class TwinkleArgs: + """Args class that mimics megatron.training.get_args() return value. + + TODO: Remove swift dependency. This class is currently designed to be compatible + with external GPTBridge. Once independent bridge logic is implemented, this + can be simplified. 
+ """ + # Model architecture + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_query_groups: int = 32 + num_layers: int = 32 + ffn_hidden_size: int = 11008 + padded_vocab_size: int = 32000 + + # Model options + group_query_attention: bool = False + add_qkv_bias: bool = False + add_bias_linear: bool = False + qk_layernorm: bool = False + multi_latent_attention: bool = False + untie_embeddings_and_output_weights: bool = True + + # MoE + num_experts: Optional[int] = None + moe_shared_expert_intermediate_size: Optional[int] = None + moe_router_enable_expert_bias: bool = False + + # MLA (Multi-Latent Attention) - for DeepSeek models + q_lora_rank: Optional[int] = None + kv_lora_rank: int = 32 + + # MTP (Multi-Token Prediction) + mtp_num_layers: int = 0 + + # Parallelism + tensor_model_parallel_size: int = 1 + pipeline_model_parallel_size: int = 1 + expert_model_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + context_parallel_size: int = 1 + sequence_parallel: bool = False + + distributed_timeout_minutes: int = 300000 + distributed_backend: str = 'nccl' + local_rank: int = 0 + rank: int = 0 + world_size: int = 1 + + # Paths and identifiers + model_dir: str = '' + hf_model_type: str = 'qwen2' + + # Task type + task_type: str = 'causal_lm' + + # Save settings + max_shard_size: str = '5GB' + + # Multimodal + is_multimodal: bool = False + + # Hub settings + use_hf: bool = False + hub_token: Optional[str] = None + + # Additional Megatron settings + fp16: bool = False + bf16: bool = True + accumulate_allreduce_grads_in_fp32: bool = False + async_tensor_model_parallel_allreduce: bool = False + use_distributed_optimizer: bool = False + overlap_grad_reduce: bool = False + overlap_param_gather: bool = False + + # Softmax type + softmax_type: str = 'vanilla' + + # Extra Megatron arguments + padding_free: bool = True + mlp_padding_free: bool = False + check_model: bool = True + initialize_embedding: bool = False + rope_scaling: Optional[Any] = None + 
torch_dtype: Optional[Any] = None + model: Optional[str] = None + model_type: Optional[str] = None + load_safetensors: Optional[bool] = None + save_safetensors: bool = True + adapters: Optional[Any] = None + merge_lora: Optional[bool] = None + + # Training settings + micro_batch_size: int = 1 + global_batch_size: int = 16 + recompute_granularity: str = 'selective' + recompute_method: Optional[str] = None + recompute_num_layers: Optional[int] = None + use_cpu_initialization: bool = False + deterministic_mode: bool = False + no_masked_softmax_fusion: bool = False + no_bias_dropout_fusion: Optional[bool] = None + no_bias_swiglu_fusion: bool = False + no_rope_fusion: Optional[bool] = None + + # LoRA settings + train_type: Optional[str] = None + lora_rank: int = 8 + lora_alpha: int = 8 + + @classmethod + def from_hf_config(cls, hf_config: Any, tp_size: int = 1, pp_size: int = 1, + ep_size: int = 1, etp_size: Optional[int] = None, + model_dir: str = '', padded_vocab_size: Optional[int] = None, + use_hf: bool = False, hub_token: Optional[str] = None): + """Create TwinkleArgs from HuggingFace config. + + Args: + hf_config: HuggingFace model configuration. + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + ep_size: Expert parallel size. + etp_size: Expert tensor parallel size (defaults to tp_size). + model_dir: Path to model directory. + padded_vocab_size: Padded vocabulary size (auto-computed if None). + use_hf: Whether to use HuggingFace Hub (vs ModelScope). + hub_token: Hub token for authentication. 
+ """ + import os + + vocab_size = getattr(hf_config, 'vocab_size', 32000) + if padded_vocab_size is None: + # Pad to multiple of tp_size * 128 for efficiency + divisor = tp_size * 128 + padded_vocab_size = ((vocab_size + divisor - 1) // divisor) * divisor + + num_attention_heads = getattr(hf_config, 'num_attention_heads', 32) + num_query_groups = getattr(hf_config, 'num_key_value_heads', num_attention_heads) + model_type = getattr(hf_config, 'model_type', 'qwen2') + + # Determine QKV bias - Qwen2 has bias by default but config doesn't expose it + if hasattr(hf_config, 'attention_bias'): + add_qkv_bias = hf_config.attention_bias + elif model_type in ('qwen2', 'qwen2_5'): + add_qkv_bias = True + else: + add_qkv_bias = False + + # MoE config + num_experts = getattr(hf_config, 'num_experts', None) or \ + getattr(hf_config, 'n_routed_experts', None) or \ + getattr(hf_config, 'num_local_experts', None) + + # QK layernorm (Qwen3) + qk_layernorm = getattr(hf_config, 'qk_layernorm', False) or \ + getattr(hf_config, 'use_qk_norm', False) + + # MLA settings (DeepSeek) + q_lora_rank = getattr(hf_config, 'q_lora_rank', None) + multi_latent_attention = q_lora_rank is not None or \ + getattr(hf_config, 'kv_lora_rank', None) is not None + + # Get distributed settings from environment + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + + return cls( + hidden_size=getattr(hf_config, 'hidden_size', 4096), + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + num_layers=getattr(hf_config, 'num_hidden_layers', 32), + ffn_hidden_size=getattr(hf_config, 'intermediate_size', 11008), + padded_vocab_size=padded_vocab_size, + group_query_attention=num_query_groups != num_attention_heads, + add_qkv_bias=add_qkv_bias, + add_bias_linear=getattr(hf_config, 'mlp_bias', False), + qk_layernorm=qk_layernorm, + multi_latent_attention=multi_latent_attention, + 
untie_embeddings_and_output_weights=not getattr(hf_config, 'tie_word_embeddings', False), + num_experts=num_experts, + moe_shared_expert_intermediate_size=getattr(hf_config, 'shared_expert_intermediate_size', None), + q_lora_rank=q_lora_rank, + kv_lora_rank=getattr(hf_config, 'kv_lora_rank', 32), + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size or tp_size, + local_rank=local_rank, + rank=rank, + world_size=world_size, + model_dir=model_dir, + hf_model_type=model_type, + use_hf=use_hf, + hub_token=hub_token, + adapters=[], # Initialize as empty list + ) + + +# ============================================================================= +# GPTBridge Adapter +# TODO: Implement independent bridge logic to remove swift dependency. +# ============================================================================= +def _import_swift_bridge(): + """Import GPTBridge from external package. + + TODO: Implement independent bridge logic in twinkle. The weight conversion + between HuggingFace and Megatron formats should be self-contained. + + Returns: + GPTBridge class if available, None otherwise. 
+ """ + global _SWIFT_BRIDGE_AVAILABLE, _SWIFT_GPT_BRIDGE_CLASS + + if _SWIFT_BRIDGE_AVAILABLE is not None: + return _SWIFT_GPT_BRIDGE_CLASS + + try: + from swift.utils import disable_safe_ddp_context_use_barrier + + with disable_safe_ddp_context_use_barrier(): + from swift.megatron.model.gpt_bridge import GPTBridge + + _SWIFT_BRIDGE_AVAILABLE = True + _SWIFT_GPT_BRIDGE_CLASS = GPTBridge + return GPTBridge + except ImportError as e: + _SWIFT_BRIDGE_AVAILABLE = False + _SWIFT_GPT_BRIDGE_CLASS = None + return None + except Exception as e: + import traceback + print(f"Warning: Failed to import GPTBridge: {e}") + traceback.print_exc() + _SWIFT_BRIDGE_AVAILABLE = False + _SWIFT_GPT_BRIDGE_CLASS = None + return None + + +def use_swift_bridge() -> bool: + """Check if GPTBridge is available.""" + _import_swift_bridge() + return _SWIFT_BRIDGE_AVAILABLE is True + + +class SwiftBridgeAdapter: + """Adapter to use swift's GPTBridge with twinkle's TwinkleArgs. + + TODO: Remove swift dependency. Implement independent bridge logic in twinkle. + + This class wraps swift's GPTBridge for weight loading/saving between + HuggingFace and Megatron formats. + """ + + def __init__(self, args: TwinkleArgs, hf_model=None, disable_tqdm: bool = False): + self.args = args + self.hf_model = hf_model + self.disable_tqdm = disable_tqdm + self._swift_bridge = None + + self._init_swift_bridge() + + def _init_swift_bridge(self): + """Initialize swift's GPTBridge with our args.""" + GPTBridge = _import_swift_bridge() + if GPTBridge is None: + raise ImportError( + "swift package is required for Megatron weight loading. 
" + "Please install: pip install ms-swift" + ) + + # Use Megatron's official set_args to set global args + from megatron.training.global_vars import set_args, get_args + + # Check if args already set + try: + existing_args = get_args() + # Args already initialized, we'll use existing + self._swift_bridge = GPTBridge(disable_tqmd=self.disable_tqdm) + except AssertionError: + # Args not initialized, set our args + set_args(self.args) + self._swift_bridge = GPTBridge(disable_tqmd=self.disable_tqdm) + + def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False): + """Load weights from HuggingFace checkpoint into Megatron model.""" + self._swift_bridge.load_weights(mg_model, hf_model_dir, is_peft_format) + + def save_weights(self, mg_models, output_dir: str, hf_model_dir: str = None, is_peft_format: bool = False): + """Save weights in HuggingFace format.""" + self._swift_bridge.save_weights(mg_models, output_dir, is_peft_format) + + +def create_bridge_adapter( + hf_config: Any, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + model_dir: str = '', + padded_vocab_size: Optional[int] = None, +) -> SwiftBridgeAdapter: + """Create a bridge adapter for weight loading/saving. + + TODO: Remove swift dependency. Implement independent bridge logic. + + Args: + hf_config: HuggingFace model config. + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + ep_size: Expert parallel size. + model_dir: Path to model directory. + padded_vocab_size: Padded vocabulary size. + + Returns: + SwiftBridgeAdapter instance. 
+ """ + args = TwinkleArgs.from_hf_config( + hf_config, + tp_size=tp_size, + pp_size=pp_size, + ep_size=ep_size, + model_dir=model_dir, + padded_vocab_size=padded_vocab_size, + ) + + return SwiftBridgeAdapter(args) + + +def create_megatron_model_with_swift( + model_path: str, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + params_dtype=None, + use_cpu_initialization: bool = True, + attention_backend: str = 'unfused', + load_weights: bool = True, +): + """Create Megatron model using swift's initialization flow. + + TODO: Remove swift dependency. Implement independent initialization logic. + + Args: + model_path: Path to HuggingFace model or model ID. + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + ep_size: Expert parallel size. + params_dtype: Parameter dtype (default: torch.bfloat16). + use_cpu_initialization: Initialize on CPU first (for memory efficiency). + attention_backend: Attention backend ('unfused' for precision, 'flash' for speed). + load_weights: Whether to load weights. + + Returns: + Tuple of (model, bridge, megatron_model_meta). 
+ """ + import torch + from transformers import AutoConfig + + if params_dtype is None: + params_dtype = torch.bfloat16 + + # Download model if needed + if not os.path.isdir(model_path): + try: + from modelscope import snapshot_download + model_path = snapshot_download(model_path) + except ImportError: + from huggingface_hub import snapshot_download + model_path = snapshot_download(model_path) + + # Load HF config + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # Import Swift modules with barrier disabled + from swift.utils import disable_safe_ddp_context_use_barrier + + with disable_safe_ddp_context_use_barrier(): + from swift.megatron import ( + MegatronArguments, convert_hf_config, get_megatron_model_meta + ) + + from megatron.training.initialize import initialize_megatron + from megatron.training import get_args + + # Check if Megatron is already initialized + try: + existing_args = get_args() + megatron_initialized = True + except AssertionError: + megatron_initialized = False + + # Get model meta first to get extra_args_provider + megatron_model_meta = get_megatron_model_meta(hf_config.model_type) + if megatron_model_meta is None: + raise ValueError(f'Model type {hf_config.model_type} not supported by Swift') + + if not megatron_initialized: + # Convert HF config to Megatron config kwargs + config_kwargs = convert_hf_config(hf_config) + + # Create MegatronArguments + megatron_args = MegatronArguments( + model=model_path, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + torch_dtype=params_dtype, + use_cpu_initialization=use_cpu_initialization, + attention_backend=attention_backend, + **config_kwargs, + ) + + # Parse to Megatron format + extra_args = megatron_args.parse_to_megatron() + + # Initialize Megatron + extra_args_provider = megatron_model_meta.extra_args_provider + initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) + 
+ # Determine pre_process and post_process based on pipeline stage + from megatron.core import parallel_state as mpu + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + + model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) + + # Load weights if requested + bridge = None + if load_weights: + bridge = megatron_model_meta.bridge_cls() + bridge.load_weights(model, model_path) + + return model, bridge, megatron_model_meta + + +class MegatronBridgeInitializer: + """Megatron model initializer using bridge-based initialization flow. + + TODO: Remove swift dependency. Implement independent initialization logic. + + Example: + initializer = MegatronBridgeInitializer( + tp_size=2, + pp_size=1, + params_dtype=torch.bfloat16, + ) + model = initializer.create_model('Qwen/Qwen2.5-7B-Instruct') + """ + + def __init__( + self, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + params_dtype=None, + use_cpu_initialization: bool = True, + attention_backend: str = 'flash', # Use flash for training performance + ): + """Initialize MegatronBridgeInitializer. + + Args: + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + ep_size: Expert parallel size. + params_dtype: Parameter dtype (default: torch.bfloat16). + use_cpu_initialization: Initialize on CPU first. + attention_backend: Attention backend. + """ + import torch + + self.tp_size = tp_size + self.pp_size = pp_size + self.ep_size = ep_size + self.params_dtype = params_dtype if params_dtype is not None else torch.bfloat16 + self.use_cpu_initialization = use_cpu_initialization + self.attention_backend = attention_backend + + self._model = None + self._bridge = None + self._model_meta = None + self._hf_config = None + + def create_model( + self, + model_path: str, + load_weights: bool = True, + ): + """Create Megatron model from HuggingFace checkpoint. + + Args: + model_path: Path to HuggingFace model or model ID. 
+ load_weights: Whether to load weights. + + Returns: + Megatron model. + """ + from transformers import AutoConfig + + # Download model if needed + if not os.path.isdir(model_path): + try: + from modelscope import snapshot_download + model_path = snapshot_download(model_path) + except ImportError: + from huggingface_hub import snapshot_download + model_path = snapshot_download(model_path) + + # Store HF config + self._hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + self._model, self._bridge, self._model_meta = create_megatron_model_with_swift( + model_path=model_path, + tp_size=self.tp_size, + pp_size=self.pp_size, + ep_size=self.ep_size, + params_dtype=self.params_dtype, + use_cpu_initialization=self.use_cpu_initialization, + attention_backend=self.attention_backend, + load_weights=load_weights, + ) + + return self._model + + @property + def hf_config(self): + """Get the HuggingFace config.""" + return self._hf_config + + @property + def bridge(self): + """Get the Swift bridge instance.""" + return self._bridge + + @property + def model_meta(self): + """Get the Megatron model meta.""" + return self._model_meta + + def load_weights(self, model, model_path: str): + """Load weights into an existing model. + + Args: + model: Megatron model. + model_path: Path to HuggingFace checkpoint. + """ + if self._bridge is None: + # Create bridge from model meta + if self._model_meta is None: + raise ValueError("Must call create_model first or provide model_meta") + self._bridge = self._model_meta.bridge_cls() + + self._bridge.load_weights(model, model_path) + + def save_weights(self, models, output_dir: str, is_peft_format: bool = False): + """Save weights in HuggingFace format. + + Args: + models: Megatron model(s). + output_dir: Output directory. + is_peft_format: Whether to save in PEFT format. 
+ """ + if self._bridge is None: + raise ValueError("Must load weights first") + + if not isinstance(models, (list, tuple)): + models = [models] + + self._bridge.save_weights(models, output_dir, is_peft_format=is_peft_format) diff --git a/src/twinkle/megatron/tuners/lora.py b/src/twinkle/megatron/tuners/lora.py index 17714e0d..dbaefe7a 100644 --- a/src/twinkle/megatron/tuners/lora.py +++ b/src/twinkle/megatron/tuners/lora.py @@ -1,6 +1,4 @@ # Copyright (c) twinkle authors. All rights reserved. -# Code adapted from huggingface/peft and ms-swift -# [SWIFT] Core LoRA implementation adapted from swift's megatron tuners. """Megatron-compatible LoRA implementation with Tensor Parallel support.""" import math import warnings diff --git a/src/twinkle/megatron/utils.py b/src/twinkle/megatron/utils.py index d53bbf7b..68e5663e 100644 --- a/src/twinkle/megatron/utils.py +++ b/src/twinkle/megatron/utils.py @@ -1,7 +1,4 @@ # Copyright (c) twinkle authors. All rights reserved. -# Code reference: Adapted from ms-swift with modifications for twinkle's multi-tenant architecture. -# Original code markers: Functions marked with [SWIFT] are adapted from swift, -# Functions marked with [TWINKLE] are original implementations. 
"""Utility functions for Megatron-Core integration.""" from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple @@ -27,9 +24,7 @@ mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') -# ============================================================================= -# [SWIFT] Config mapping from HuggingFace to Megatron - adapted from swift -# ============================================================================= +# Config mapping from HuggingFace to Megatron CONFIG_MAPPING = { 'num_layers': ['num_hidden_layers'], 'hidden_size': ['hidden_size'], @@ -75,14 +70,9 @@ } -# ============================================================================= -# [TWINKLE] Multi-tenant Process Group Management -# ============================================================================= class TenantProcessGroupManager: """Manager for multi-tenant process groups. - [TWINKLE] This is an original implementation for twinkle's multi-tenant architecture. - In a multi-tenant scenario, multiple users may share the same base model in a single process, each with their own LoRA adapters. To avoid communication interference between tenants, we need to maintain separate process groups for each tenant. @@ -442,12 +432,12 @@ def get_tenant_manager() -> TenantProcessGroupManager: # ============================================================================= -# [SWIFT] Layer finding utilities - adapted from swift + # ============================================================================= def find_layers(model: nn.Module, cond_fn) -> List[str]: """Find all layers in model matching condition function. - [SWIFT] Adapted from swift. + Args: model: The model to search. @@ -466,7 +456,7 @@ def find_layers(model: nn.Module, cond_fn) -> List[str]: def find_all_linears(model: nn.Module) -> List[str]: """Find all linear layers suitable for LoRA in a Megatron model. - [SWIFT] Adapted from swift. + Args: model: The Megatron model. 
@@ -487,7 +477,7 @@ def _cond(name: str, module: nn.Module) -> bool: def find_router(model: nn.Module) -> List[str]: """Find all MoE router layers in a Megatron model. - [SWIFT] Adapted from swift. + Args: model: The Megatron model. @@ -501,7 +491,7 @@ def find_router(model: nn.Module) -> List[str]: def find_embedding(model: nn.Module) -> List[str]: """Find all embedding layers in a Megatron model. - [SWIFT] Adapted from swift. + Args: model: The Megatron model. @@ -515,7 +505,7 @@ def find_embedding(model: nn.Module) -> List[str]: def get_target_modules(model: nn.Module, target_modules: List[str]) -> List[str]: """Expand target module specifications to actual module names. - [SWIFT] Adapted from swift. + Args: model: The Megatron model. @@ -540,7 +530,7 @@ def get_target_modules(model: nn.Module, target_modules: List[str]) -> List[str] def set_linear_is_expert(model: nn.Module): """Mark expert linear layers in MoE models. - [SWIFT] Adapted from swift. + Args: model: The Megatron model. @@ -574,12 +564,12 @@ def deep_getattr(obj: Any, attr: str, default: Any = None) -> Any: # ============================================================================= -# [SWIFT] Config conversion - adapted from swift with Qwen3 enhancements + # ============================================================================= def _convert_hf_config(config, _internal_call: bool = False) -> Dict[str, Any]: """Convert HuggingFace config to Megatron config dict. - [SWIFT] Adapted from swift. + Args: config: HuggingFace model config. @@ -629,7 +619,7 @@ def _convert_hf_config(config, _internal_call: bool = False) -> Dict[str, Any]: def convert_hf_config(config) -> Dict[str, Any]: """Convert HuggingFace config to Megatron-compatible config. - [SWIFT] Adapted from swift with Qwen3 specific handling. + Args: config: HuggingFace model config. 
@@ -654,6 +644,14 @@ def convert_hf_config(config) -> Dict[str, Any]: first_k_dense_replace = res.pop('first_k_dense_replace', None) n_shared_experts = res.pop('n_shared_experts', None) + # ==== Qwen2/Qwen2.5 Model specific settings ==== + if llm_architectures == 'Qwen2ForCausalLM': + # Qwen2/Qwen2.5 uses bias=True for Q, K, V projections (hardcoded in transformers) + # but the config doesn't have 'attention_bias' field + if 'add_qkv_bias' not in res: + res['add_qkv_bias'] = True + res['swiglu'] = True + # ==== Qwen3 Dense Model specific settings ==== if llm_architectures == 'Qwen3ForCausalLM': res['qk_layernorm'] = True @@ -701,7 +699,7 @@ def convert_hf_config(config) -> Dict[str, Any]: def patch_deepcopy(): """Context manager to handle tp_group in deepcopy operations. - [SWIFT] Adapted from swift. + WHY THIS IS NECESSARY: ---------------------- @@ -751,7 +749,7 @@ def new_deepcopy(x, *args, **kwargs): # ============================================================================= -# [SWIFT] Sharded state dict for tuners - adapted from swift + # ============================================================================= def tuners_sharded_state_dict( module: nn.Module, @@ -761,7 +759,7 @@ def tuners_sharded_state_dict( ) -> Dict[str, Any]: """Generate sharded state dict for PEFT tuners. - [SWIFT] Adapted from swift. + Args: module: The module to generate state dict for. @@ -867,12 +865,12 @@ def prepare_lora_model( # ============================================================================= -# [SWIFT] Layer spec utilities - adapted from swift + # ============================================================================= def get_local_layer_specs(config, layer_specs: List, vp_stage: Optional[int] = None): """Get local layer specifications for current pipeline rank. - [SWIFT] Adapted from swift. + Args: config: Megatron transformer config. 
@@ -937,7 +935,7 @@ def get_padding_to( # ============================================================================= -# [SWIFT] Forward step helper - adapted from swift + # ============================================================================= def forward_step_helper(model: nn.Module, inputs: Dict[str, Any], config) -> Optional[torch.Tensor]: """Helper for pipeline parallel forward step. @@ -955,9 +953,17 @@ def forward_step_helper(model: nn.Module, inputs: Dict[str, Any], config) -> Opt from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' + # Get micro_batch_size from input tensor, not config + # For padding_free (qkv_format 'thd'), use 1 + micro_batch_size = 1 if not getattr(config, 'padding_free', False): - micro_batch_size = config.micro_batch_size + # Infer batch size from input_ids or position_ids + if 'input_ids' in inputs: + micro_batch_size = inputs['input_ids'].shape[0] + elif 'position_ids' in inputs: + micro_batch_size = inputs['position_ids'].shape[0] + else: + micro_batch_size = 1 seq_length = inputs['position_ids'].shape[-1] if config.sequence_parallel: seq_length //= mpu.get_tensor_model_parallel_world_size() diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index ce5c1830..77c3f691 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -110,6 +110,7 @@ def __init__( mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', use_distributed_optimizer: bool = True, load_weights: bool = True, + use_megatron_bridge: bool = True, # Use bridge-based initialization (recommended) **kwargs, ): check_megatron_available() @@ -118,6 +119,14 @@ def __init__( self.model_id = pretrained_model_name_or_path self.device_mesh = device_mesh self.mixed_precision = mixed_precision + self.use_megatron_bridge = use_megatron_bridge + + # Load HuggingFace config first + 
model_path = HubOperation.download_model(pretrained_model_name_or_path) + self._load_hf_config(model_path) + + # Store model_path for later use + self._model_path = model_path # Create Megatron strategy self.strategy = MegatronStrategy( @@ -130,12 +139,9 @@ def __init__( mixed_precision=mixed_precision, ) - # Initialize parallel state - self.strategy.initialize() - - # Load HuggingFace config - model_path = HubOperation.download_model(pretrained_model_name_or_path) - self._load_hf_config(model_path) + # Initialize parallel state (skip if using bridge init, as it handles this) + if not use_megatron_bridge: + self.strategy.initialize() # Create Megatron model self.model = self._create_megatron_model(model_path, load_weights, **kwargs) @@ -168,13 +174,89 @@ def _create_megatron_model( Returns: Megatron model on GPU. """ - from twinkle.megatron.model.initializer import MegatronModelInitializer - params_dtype = torch.bfloat16 if self.mixed_precision == 'fp16': params_dtype = torch.float16 elif self.mixed_precision == 'no': params_dtype = torch.float32 + + if self.use_megatron_bridge: + # Use bridge-based initialization (recommended) + # This ensures all patches are applied and config is correctly generated + return self._create_megatron_model_with_bridge(model_path, load_weights, params_dtype, **kwargs) + else: + # Use twinkle's native initialization + return self._create_megatron_model_native(model_path, load_weights, params_dtype, **kwargs) + + def _create_megatron_model_with_bridge( + self, + model_path: str, + load_weights: bool, + params_dtype: torch.dtype, + **kwargs, + ) -> nn.Module: + """Create Megatron model using bridge-based initialization flow. + + This approach uses the bridge initialization which includes: + - Proper config conversion from HuggingFace to Megatron format + - Correct Megatron initialization (initialize_megatron) + - Correct model creation (model_provider) + - All necessary patches (RoPE, TransformerLayer, etc.) 
+ + Args: + model_path: Path to HuggingFace model. + load_weights: Whether to load weights. + params_dtype: Parameter dtype. + **kwargs: Additional arguments. + + Returns: + Megatron model on GPU. + """ + from twinkle.megatron.model.swift_bridge import MegatronBridgeInitializer + + # Create bridge-based initializer + self._bridge_initializer = MegatronBridgeInitializer( + tp_size=self.strategy.tp_size, + pp_size=self.strategy.pp_size, + ep_size=self.strategy.ep_size, + params_dtype=params_dtype, + use_cpu_initialization=True, + attention_backend='flash', # Use flash for training performance + ) + + # Create model (this calls initialize_megatron internally) + model = self._bridge_initializer.create_model(model_path, load_weights=load_weights) + + # Update strategy state since bridge has initialized Megatron + self.strategy._initialized = True + self.strategy._parallel_state = mpu + + # Move to GPU + model = self._move_model_to_gpu(model) + + return model + + def _create_megatron_model_native( + self, + model_path: str, + load_weights: bool, + params_dtype: torch.dtype, + **kwargs, + ) -> nn.Module: + """Create Megatron model using twinkle's native initialization. + + This is the fallback method when bridge is not available. + + Args: + model_path: Path to HuggingFace model. + load_weights: Whether to load weights. + params_dtype: Parameter dtype. + **kwargs: Additional arguments. + + Returns: + Megatron model on GPU. + """ + from twinkle.megatron.model.initializer import MegatronModelInitializer initializer = MegatronModelInitializer( tp_size=self.strategy.tp_size, @@ -207,7 +289,12 @@ def _move_model_to_gpu(self, model: nn.Module) -> nn.Module: return model def _lazy_wrap_model(self): - """Lazily wrap model with distributed wrapper.""" + """Lazily wrap model with distributed wrapper. + + Note: This should only be called after prepare_training() has been + executed on all workers. 
Direct calls from forward() may cause + deadlocks if not all DP ranks are participating. + """ if not self._model_wrapped: # Find an optimizer from any adapter group (prefer default, then first available) optimizer = None @@ -227,6 +314,16 @@ def _lazy_wrap_model(self): self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) self.optimizer_group[optimizer_adapter].optimizer = optimizer self._model_wrapped = True + + @remote_function(dispatch='all') + def prepare_training(self, **kwargs): + """Prepare model for training. + + Note: In Ray-based Megatron training, we skip DDP wrapping to avoid + deadlocks from collective operations. Each DP replica trains independently. + This method still calls _lazy_wrap_model for any non-DDP setup needed. + """ + self._lazy_wrap_model() @remote_function() def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): @@ -308,7 +405,12 @@ def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return {'logits': outputs} def _forward_step_pipeline(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Forward step with pipeline parallelism.""" + """Forward step with pipeline parallelism. + + Note: For PP > 1, the forward pass is handled by Megatron's pipeline scheduler + in forward_backward(). This method is for simple forward-only inference. + For training, use forward_backward() which uses get_forward_backward_func(). + """ from twinkle.megatron.utils import forward_step_helper model = self.strategy.unwrap_model(self.model) @@ -385,7 +487,12 @@ def backward(self, **kwargs): @remote_function(collect='avg') def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): - """Combined forward and backward pass. + """Combined forward and backward pass using Megatron's scheduler. 
+ + Always uses Megatron's get_forward_backward_func() which handles: + - Pipeline scheduling (1F1B, interleaved, or no-pipeline) + - Communication between stages (using proper process groups for multi-tenant isolation) + - Gradient accumulation Args: inputs: Model inputs. @@ -394,10 +501,133 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr Returns: Loss value. """ - self.forward(inputs=inputs, **kwargs) - loss = self.calculate_loss(**kwargs) - self.backward(**kwargs) - return loss + from functools import partial + from megatron.core.pipeline_parallel import get_forward_backward_func + + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + self._lazy_wrap_model() + + # Encode inputs if needed + if isinstance(inputs, dict) and 'input_ids' not in inputs: + if optimizer_config.template is not None: + inputs = optimizer_config.template.encode(inputs) + if isinstance(inputs, list) and 'input_ids' not in inputs[0]: + if optimizer_config.template is not None: + inputs = optimizer_config.template.batch_encode(inputs) + + # Process inputs + processor = optimizer_config.processor + if processor is not None: + inputs = processor(inputs) + + # Store labels before removing from inputs + labels = inputs.pop('labels', None) + + # Get sequence length and batch size + seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1 + micro_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 + + # Move labels to GPU if needed + if labels is not None and not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, device=torch.cuda.current_device()) + elif labels is not None: + labels = labels.to(torch.cuda.current_device()) + + # Define loss function that matches Megatron's expected signature + # loss_func(output_tensor) -> (loss, {str: tensor}) + def loss_func(labels_tensor, loss_instance, output_tensor): + if labels_tensor is None or 
loss_instance is None: + loss = torch.tensor(0.0, device=output_tensor.device, requires_grad=True) + return loss, {'loss': loss} + + inputs_dict = {'labels': labels_tensor} + outputs_dict = {'logits': output_tensor} + loss = loss_instance(inputs_dict, outputs_dict) + + # Megatron expects (loss, {str: tensor}) for logging + return loss, {'loss': loss.detach()} + + # Define forward step function for Megatron + # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) + def forward_step_func(data_iterator, model): + batch = next(data_iterator) + input_ids = batch.get('input_ids') + position_ids = batch.get('position_ids') + attention_mask = batch.get('attention_mask') + + # Create position_ids if not provided + if position_ids is None and input_ids is not None: + position_ids = torch.arange( + input_ids.shape[1], + device=input_ids.device, + dtype=torch.long, + ).unsqueeze(0).expand(input_ids.shape[0], -1) + + # Forward pass + output_tensor = model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + ) + + # Return output and partial loss function + return output_tensor, partial(loss_func, labels, optimizer_config.loss_instance) + + # Get Megatron's forward-backward function + # This automatically selects the right scheduler based on PP config: + # - PP > 1: forward_backward_pipelining_without_interleaving (or with interleaving if VPP) + # - PP = 1: forward_backward_no_pipelining + forward_backward_func = get_forward_backward_func() + + # Create single-item iterator + data_iter = iter([inputs]) + + # Run forward-backward with Megatron's scheduler + # Megatron handles all communication internally using proper process groups + losses = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iter, + model=[self.model], + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + ) + + # Extract loss from results (only last PP stage returns 
non-empty) + loss = 0.0 + if losses: + for loss_dict in losses: + if isinstance(loss_dict, dict) and 'loss' in loss_dict: + loss = loss_dict['loss'] + break + elif isinstance(loss_dict, torch.Tensor): + loss = loss_dict + break + + # For PP > 1, broadcast loss from last PP stage to all ranks + from megatron.core import parallel_state as mpu + if mpu.get_pipeline_model_parallel_world_size() > 1: + if isinstance(loss, torch.Tensor): + loss_tensor = loss.detach().clone() + else: + loss_tensor = torch.tensor(loss, dtype=torch.float32, device=torch.cuda.current_device()) + + # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1) + src_rank = mpu.get_pipeline_model_parallel_last_rank() + torch.distributed.broadcast( + loss_tensor, + src=src_rank, + group=mpu.get_pipeline_model_parallel_group() + ) + loss = loss_tensor.item() + + optimizer_config.cur_step += 1 + + if isinstance(loss, torch.Tensor): + return loss.detach().cpu().float().numpy() + return float(loss) @remote_function(dispatch='all') def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs): @@ -422,13 +652,16 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwarg def step(self, **kwargs): """Optimizer step. - For PEFT models, gradients are NOT synchronized across DP ranks - because each DP replica trains independently with different data. - This is a common pattern for PEFT training where gradient averaging - is not strictly necessary. + For DDP-wrapped models: + - Gradients are synchronized automatically during backward via DDP + + For non-DDP models (e.g., PEFT/LoRA): + - Gradients are NOT synchronized across DP ranks + - Each DP replica trains independently with different data + - This is a common pattern for PEFT training where the overhead of + gradient averaging is not worth the benefit - Note: Uses dispatch='all' to ensure all workers execute this method, - though gradient sync is disabled for PEFT models. 
+ Note: Uses dispatch='all' to ensure all workers execute this method. Args: **kwargs: Additional arguments. @@ -439,10 +672,13 @@ def step(self, **kwargs): if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): return - # Note: For PEFT/LoRA models, we skip gradient synchronization across DP ranks. - # Each DP replica trains independently. This avoids distributed communication - # complexity and is acceptable for most PEFT training scenarios. - # If gradient averaging is needed, use DDP-wrapped models instead. + # For DDP-wrapped models, gradients are already synchronized during backward + if self._is_model_ddp_wrapped(): + # For Megatron DDP, ensure gradient buffers are finalized + if hasattr(self.model, 'finish_grad_sync'): + self.model.finish_grad_sync() + # For non-DDP models (e.g., PEFT), we skip gradient synchronization + # Each DP replica trains independently, which is acceptable for PEFT optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer correctly before stepping' @@ -470,6 +706,8 @@ def _get_unwrapped_model(self) -> nn.Module: def zero_grad(self, **kwargs): """Zero gradients. + For DDP-wrapped models, also zeros the DDP gradient buffers. + Args: **kwargs: Additional arguments. 
""" @@ -482,6 +720,10 @@ def zero_grad(self, **kwargs): optimizer = optimizer_config.optimizer if optimizer is not None: optimizer.zero_grad(**kwargs) + + # For Megatron DDP, zero the gradient buffer + if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'): + self.model.zero_grad_buffer() @remote_function() def lr_step(self, **kwargs): @@ -598,26 +840,38 @@ def save(self, output_dir: str, **kwargs): self._save_tokenizer(output_dir, adapter_name) def _save_hf_format(self, output_dir: str, adapter_name: str): - """Save in HuggingFace format using swift's GPTBridge.""" + """Save in HuggingFace format using bridge adapter.""" from twinkle.megatron.model.bridge import TwinkleBridgeAdapter import os - # Only save from last PP rank + # Only save from last PP rank and first DP rank to avoid conflicts if not self.strategy.is_pipeline_last_stage(): return + + # Only let DP rank 0 save to avoid file conflicts + if hasattr(self.strategy, 'dp_rank') and self.strategy.dp_rank != 0: + return + + # Also check via parallel_state if available + try: + from megatron.core import parallel_state as mpu + if mpu.is_initialized() and mpu.get_data_parallel_rank() != 0: + return + except (ImportError, AssertionError): + pass os.makedirs(output_dir, exist_ok=True) - # Use TwinkleBridgeAdapter which wraps swift's GPTBridge + # Use TwinkleBridgeAdapter for weight conversion adapter = TwinkleBridgeAdapter( hf_config=self.hf_config, tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, ep_size=self.strategy.ep_size, - model_path=self.pretrained_model_name_or_path, + model_path=self.model_id, ) - # Use swift's bridge to save weights + # Use bridge to save weights in HuggingFace format adapter.save_weights([self.model], output_dir, is_peft_format=False) # Save config @@ -667,8 +921,6 @@ def _patch_peft_for_megatron(cls): Megatron's TransformerConfig doesn't have a .get() method like HuggingFace configs. 
This patch handles the AttributeError that occurs when PEFT tries to check tie_word_embeddings. - - Reference: swift/swift/megatron/init.py::_patch_peft_BaseTuner """ if cls._peft_patched: return @@ -776,6 +1028,8 @@ def add_adapter_to_model( self.optimizer_group[adapter_name].template = default_config.template if default_config.processor: self.optimizer_group[adapter_name].processor = default_config.processor + if default_config.loss_instance: + self.optimizer_group[adapter_name].loss_instance = default_config.loss_instance @remote_function() def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs): diff --git a/src/twinkle/model/strategy/megatron.py b/src/twinkle/model/strategy/megatron.py index 487c77b9..c4918e0c 100644 --- a/src/twinkle/model/strategy/megatron.py +++ b/src/twinkle/model/strategy/megatron.py @@ -317,99 +317,177 @@ def get_params_dtype(self) -> torch.dtype: return torch.float16 return torch.float32 + def _get_transformer_config(self, model: nn.Module): + """Get TransformerConfig from model, handling PEFT wrappers. + + Args: + model: The model (may be wrapped with PEFT). + + Returns: + TransformerConfig if found, None otherwise. 
+ """ + # Direct config attribute + config = getattr(model, 'config', None) + if config is not None and hasattr(config, 'tensor_model_parallel_size'): + return config + + # PEFT model: model.base_model.model.config + if hasattr(model, 'base_model'): + base = model.base_model + if hasattr(base, 'model'): + config = getattr(base.model, 'config', None) + if config is not None and hasattr(config, 'tensor_model_parallel_size'): + return config + # Try base.config + config = getattr(base, 'config', None) + if config is not None and hasattr(config, 'tensor_model_parallel_size'): + return config + + # Wrapped model: model.model.config + if hasattr(model, 'model'): + config = getattr(model.model, 'config', None) + if config is not None and hasattr(config, 'tensor_model_parallel_size'): + return config + + # Recursive search through modules + for name, module in model.named_modules(): + config = getattr(module, 'config', None) + if config is not None and hasattr(config, 'tensor_model_parallel_size'): + return config + + return None + def wrap_model( self, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, + use_distributed_optimizer: bool = True, ) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]: - """Wrap model with distributed wrapper for data parallelism. + """Wrap model with Megatron DDP for data parallelism. + + This method behaves differently based on twinkle's execution mode: - In Megatron, TP/PP/CP/EP parallelism is already handled during model creation - (via TransformerConfig and parallel_state). This method only handles Data - Parallel (DP) wrapping, which synchronizes gradients across DP ranks. 
+ **Local mode (torchrun)**: + - Uses Megatron native DDP wrapping + - All processes are synchronized by torchrun, so collective ops work - For PEFT/LoRA models: - - We skip DDP wrapping to avoid compatibility issues - - Gradients are synchronized manually via all_reduce_gradients() - - This is more flexible and works reliably with dynamically added LoRA modules + **Ray mode**: + - Currently skips DDP wrapping to avoid deadlocks + - Ray's asynchronous actor model makes collective synchronization hard + - Each DP replica trains independently - For full model training (non-PEFT): - - Consider using Megatron's native training.setup_model_and_optimizer() - - Or use Megatron DDP with proper TransformerConfig + **Transformers/Accelerate comparison**: + - Accelerate's `prepare()` works in Ray because it's a local operation + - Megatron DDP's `broadcast_params()` is a collective that needs sync Args: - model: The Megatron model to wrap (already parallelized via TP/PP). - optimizer: Optional optimizer (not wrapped here; use DistributedOptimizer separately if needed). + model: The Megatron model (already has TP/PP via TransformerConfig). + optimizer: Optional optimizer. + use_distributed_optimizer: Whether to use distributed optimizer. Returns: - Tuple of (wrapped_model, wrapped_optimizer). - For PEFT models, wrapped_model is the original model (no DDP wrapper). + Tuple of (wrapped_model, optimizer). """ if not self._initialized: self.initialize() - # Check if this is a PEFT/LoRA model - is_peft_model = hasattr(model, 'peft_config') or hasattr(model, 'base_model') - - if is_peft_model: - # For PEFT models, skip DDP wrapping entirely. - # Reasons: - # 1. PEFT models have dynamically added modules that may cause issues with DDP - # 2. LoRA typically has very few trainable parameters, so manual gradient sync is efficient - # 3. Megatron DDP requires TransformerConfig which may not be accessible after PEFT wrapping - # 4. 
PyTorch DDP has device placement issues when model uses CPU initialization - # - # Instead, gradients should be synchronized manually using all_reduce_gradients() - # after backward() and before optimizer.step(). - return model, optimizer + # Determine execution mode + import os + twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') - # For non-PEFT models, we can use Megatron DDP or PyTorch DDP + # Check DP world size dp_group = self.dp_group - if dp_group is None or dist.get_world_size(dp_group) <= 1: - # No DP needed (single GPU or no DP group) + dp_world_size = 1 + if dp_group is not None: + dp_world_size = dist.get_world_size(dp_group) + + if dp_world_size <= 1: + # No DP needed (single GPU or TP-only) return model, optimizer - # Get model config for Megatron DDP - config = getattr(model, 'config', None) + if twinkle_mode == 'ray': + # In Ray mode, skip DDP for now due to collective sync issues + # TODO: Implement Ray-compatible DDP with barrier synchronization + import warnings + warnings.warn( + "Skipping Megatron DDP in Ray mode. Each DP replica trains independently. " + "For synchronized training, use torchrun (TWINKLE_MODE=local)." + ) + return model, optimizer - # Check if model is on GPU (required for DDP) - model_device = next(model.parameters()).device - if model_device.type == 'cpu': - # Model is on CPU, need to move to GPU first - # This happens when use_cpu_initialization=True - local_rank = dist.get_rank() % torch.cuda.device_count() - model = model.to(f'cuda:{local_rank}') + # Local mode (torchrun): Use Megatron native DDP + return self._wrap_with_megatron_ddp(model, optimizer, use_distributed_optimizer) + + def _wrap_with_megatron_ddp( + self, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer], + use_distributed_optimizer: bool, + ) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]: + """ + Wrap model with Megatron native DDP (for torchrun mode). 
+ """ + from megatron.core.distributed import DistributedDataParallelConfig + from megatron.core.transformer.module import Float16Module + + # Get TransformerConfig from model + config = self._get_transformer_config(model) + if config is None: + import warnings + warnings.warn( + "Could not find TransformerConfig. Skipping DDP wrapping. " + "Gradient sync will need to be done manually." + ) + return model, optimizer - if config is not None and hasattr(config, 'tensor_model_parallel_size'): - # Model has TransformerConfig, use Megatron DDP - try: - from megatron.core.distributed import DistributedDataParallelConfig - ddp_config = DistributedDataParallelConfig( - grad_reduce_in_fp32=True, - overlap_grad_reduce=False, - use_distributed_optimizer=self.use_distributed_optimizer, - check_for_nan_in_grad=False, - bucket_size=None, # No bucketing for simpler gradient sync - ) - wrapped_model = MegatronDDP( - config=config, - ddp_config=ddp_config, - module=model, - ) - return wrapped_model, optimizer - except (ImportError, TypeError) as e: - # Fallback to PyTorch DDP if Megatron DDP fails - pass - - # Fallback: PyTorch DDP for models without TransformerConfig - from torch.nn.parallel import DistributedDataParallel as TorchDDP - wrapped_model = TorchDDP( - model, - process_group=dp_group, - # Note: Don't use device_ids for multi-GPU models or when model spans devices + # Ensure model is on GPU + try: + model_device = next(model.parameters()).device + if model_device.type == 'cpu': + local_rank = dist.get_rank() % torch.cuda.device_count() + model = model.to(f'cuda:{local_rank}') + except StopIteration: + pass # No parameters + + # Wrap with Float16Module for mixed precision (like Megatron's get_model) + if (config.fp16 or config.bf16) and not isinstance(model, Float16Module): + # Check if the inner model (for PEFT) needs wrapping + inner_model = model + if hasattr(model, 'base_model') and hasattr(model.base_model, 'model'): + inner_model = model.base_model.model + + # Only 
wrap if not already wrapped + if not isinstance(inner_model, Float16Module): + # For PEFT models, we can't easily wrap the inner model + # Just proceed without Float16Module + if not hasattr(model, 'base_model'): + model = Float16Module(config, model) + + # Create DDP config + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + use_distributed_optimizer=use_distributed_optimizer, ) - return wrapped_model, optimizer + # Wrap with MegatronDDP + try: + wrapped_model = MegatronDDP( + config=config, + ddp_config=ddp_config, + module=model, + ) + + # Broadcast params from data parallel src rank + # In torchrun mode, all ranks enter here simultaneously, so this works + wrapped_model.broadcast_params() + + return wrapped_model, optimizer + + except Exception as e: + import warnings + warnings.warn(f"Failed to wrap with Megatron DDP: {e}. Using unwrapped model.") + return model, optimizer def unwrap_model(self, model: nn.Module) -> nn.Module: """Unwrap the distributed model to get the base model. 
From 7ff8fc39d035ed1a82b00b2311e302bd314d669a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 Jan 2026 17:04:38 +0800 Subject: [PATCH 03/22] local 4d --- cookbook/megatron/lora.py | 326 +++------ cookbook/megatron/lora_ray.py | 227 ++++++ .../loss/vocab_parallel_cross_entropy.py | 81 +-- src/twinkle/megatron/model/__init__.py | 2 + src/twinkle/megatron/model/bridge.py | 418 ++++++++++- src/twinkle/megatron/model/initializer.py | 27 +- src/twinkle/megatron/model/swift_bridge.py | 673 ------------------ src/twinkle/megatron/worker.py | 368 ++++++++++ src/twinkle/model/megatron.py | 287 ++++++-- tests/test_parallelism.py | 405 +++++++++++ 10 files changed, 1780 insertions(+), 1034 deletions(-) create mode 100644 cookbook/megatron/lora_ray.py delete mode 100644 src/twinkle/megatron/model/swift_bridge.py create mode 100644 src/twinkle/megatron/worker.py create mode 100644 tests/test_parallelism.py diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index 907828e5..04a9fcdd 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -1,30 +1,23 @@ # Copyright (c) twinkle authors. All rights reserved. -"""Megatron-Core LoRA training example with full 4D parallelism. +"""Megatron-Core LoRA training example. -This example demonstrates LoRA fine-tuning using Megatron-Core backend. -Supports Tensor Parallel (TP), Pipeline Parallel (PP), Context Parallel (CP), -and Data Parallel (DP). DP is automatically calculated from WORLD_SIZE. +Usage (8 GPUs with TP2 PP2 CP2): + torchrun --nproc_per_node=8 cookbook/megatron/lora.py --tp_size 2 --pp_size 2 --cp_size 2 -The script uses Megatron's get_forward_backward_func() for unified pipeline -scheduling, ensuring proper multi-tenant isolation through process groups. +Usage (4 GPUs with TP2 PP2): + torchrun --nproc_per_node=4 cookbook/megatron/lora.py --tp_size 2 --pp_size 2 -TODO: Add Expert Parallel (EP) support for MoE models. 
- -Usage (8 GPUs with CP2 PP2 TP2, DP auto-calculated as 1): - torchrun --nproc_per_node=8 cookbook/megatron/lora.py \ - --tp_size 2 --pp_size 2 --cp_size 2 - -Usage (4 GPUs with TP2, DP auto-calculated as 2): - torchrun --nproc_per_node=4 cookbook/megatron/lora.py --tp_size 2 - -Usage (single GPU for debugging): - torchrun --nproc_per_node=1 cookbook/megatron/lora.py - -Note: WORLD_SIZE is automatically detected from torchrun, no need to specify it twice. +Usage (single GPU): + torchrun --nproc_per_node=1 cookbook/megatron/lora.py --tp_size 1 --pp_size 1 """ import argparse import os +# CRITICAL: Set CUDA device before any CUDA imports to ensure correct device placement +import torch +LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) +torch.cuda.set_device(LOCAL_RANK) + import numpy as np from peft import LoraConfig from torch.optim import AdamW @@ -34,243 +27,128 @@ from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta -from twinkle.loss import MegatronCrossEntropyLoss +from twinkle.loss import VocabParallelCrossEntropyLoss from twinkle.model import MegatronModel from twinkle.processor import InputProcessor logger = get_logger() - -def parse_args(): - parser = argparse.ArgumentParser(description='Megatron LoRA Training with 4D Parallelism') - - # Mode selection - parser.add_argument('--mode', type=str, default='local', - choices=['local', 'ray'], - help='Distributed mode: local (torchrun) or ray') - - # Number of GPUs - parser.add_argument('--nproc_per_node', type=int, default=4, - help='Total number of GPUs') - - # 4D Parallelism configuration - # Total GPUs = DP * CP * PP * TP (DP is auto-calculated) - # TODO: Add EP (Expert Parallel) for MoE models - parser.add_argument('--tp_size', type=int, default=1, - help='Tensor Parallel size (splits model layers horizontally)') - parser.add_argument('--pp_size', type=int, default=1, - 
help='Pipeline Parallel size (splits model layers vertically)') - parser.add_argument('--cp_size', type=int, default=1, - help='Context Parallel size (splits sequence across GPUs)') - # Note: DP size is automatically calculated as: WORLD_SIZE / (TP * PP * CP) - - # Sequence parallel (usually enabled with TP > 1) - parser.add_argument('--sequence_parallel', action='store_true', default=False, - help='Enable sequence parallelism (recommended when TP > 1)') - - # Max steps for quick testing - parser.add_argument('--max_steps', type=int, default=None, - help='Maximum training steps (for testing)') - - args = parser.parse_args() - - # Auto-detect world size from environment (set by torchrun) - world_size = int(os.environ.get('WORLD_SIZE', '1')) - args.world_size = world_size - - # Auto-calculate DP size from total GPUs and model parallel sizes - model_parallel_size = args.tp_size * args.pp_size * args.cp_size - if world_size % model_parallel_size != 0: - raise ValueError( - f'WORLD_SIZE ({world_size}) must be divisible by ' - f'TP({args.tp_size}) * PP({args.pp_size}) * CP({args.cp_size}) = {model_parallel_size}' - ) - args.dp_size = world_size // model_parallel_size - - logger.info(f'4D Parallelism config: DP={args.dp_size} (auto), TP={args.tp_size}, ' - f'PP={args.pp_size}, CP={args.cp_size}, Total GPUs={world_size}') - - return args - - -def create_device_mesh(args) -> DeviceMesh: - """Create device mesh for Megatron 4D parallelism. - - Megatron uses the following parallelism hierarchy (outer to inner): - - Data Parallel (DP): Replicates model, splits data (auto-calculated) - - Context Parallel (CP): Splits sequence across GPUs - - Pipeline Parallel (PP): Splits layers across stages - - Tensor Parallel (TP): Splits layers horizontally - - TODO: Add Expert Parallel (EP) dimension for MoE models. 
- - Mesh shape: (dp, cp, pp, tp) - """ - total_gpus = args.world_size - - # Create mesh with shape (dp, cp, pp, tp) - mesh = np.arange(total_gpus).reshape(args.dp_size, args.cp_size, args.pp_size, args.tp_size) - - device_mesh = DeviceMesh( - device_type='cuda', - mesh=mesh, - mesh_dim_names=('dp', 'cp', 'pp', 'tp'), +# Parse arguments +parser = argparse.ArgumentParser() +parser.add_argument('--tp_size', type=int, default=1) +parser.add_argument('--pp_size', type=int, default=1) +parser.add_argument('--cp_size', type=int, default=1) +parser.add_argument('--max_steps', type=int, default=None) +args = parser.parse_args() + +# Get parallelism config +WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) +TP_SIZE = args.tp_size +PP_SIZE = args.pp_size +CP_SIZE = args.cp_size +DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) + +# Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost +# For mesh shape, we reverse the order: (pp, dp, cp, tp) where rightmost is innermost +# This ensures DP groups match between twinkle and Megatron +device_mesh = DeviceMesh( + device_type='cuda', + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), + mesh_dim_names=('pp', 'dp', 'cp', 'tp'), +) + +device_group = [ + DeviceGroup( + name='model', + ranks=list(range(WORLD_SIZE)), + device_type=Platform.get_platform().device_prefix(), ) - return device_mesh - +] -def create_device_group(args): - """Create device group for model placement.""" - device_group = [ - DeviceGroup( - name='model', - ranks=list(range(args.world_size)), - device_type=Platform.get_platform().device_prefix(), - ) - ] - return device_group +twinkle.initialize( + mode='local', + nproc_per_node=WORLD_SIZE, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, +) def create_dataset(): - """Create and preprocess dataset.""" dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) dataset.set_template('Qwen3Template', 
model_id='ms://Qwen/Qwen2.5-7B-Instruct') dataset.map('CompetitionMathProcessor') - # IMPORTANT: Use load_from_cache_file=False to avoid stale cache with incorrect labels dataset.encode(batched=True, load_from_cache_file=False) return dataset -def train(args): - """Main training function with 4D parallelism support.""" - # Create dataloader - dataloader = DataLoader(dataset=create_dataset, batch_size=8) - - # Create Megatron model with 4D parallelism - # TODO: Add expert_model_parallel_size for MoE models +def train(): + # Use smaller batch size for single GPU to avoid OOM + batch_size = 2 if WORLD_SIZE == 1 else 8 + dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) + model = MegatronModel( pretrained_model_name_or_path='ms://Qwen/Qwen2.5-7B-Instruct', - tensor_model_parallel_size=args.tp_size, - pipeline_model_parallel_size=args.pp_size, - context_parallel_size=args.cp_size, - sequence_parallel=args.sequence_parallel, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, mixed_precision='bf16', + # Use 'full' recompute for single GPU to reduce memory usage + recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective', ) - - # Set template, processor, loss on DEFAULT adapter FIRST - # These will be copied when adding LoRA adapter - model.set_template('Qwen3Template') - model.set_processor(InputProcessor, padding_side='right') - model.set_loss(MegatronCrossEntropyLoss) - - # Configure LoRA adapter - lora_config = LoraConfig( - target_modules='all-linear' - ) - - # Add LoRA adapter - template, processor, loss_instance will be copied from default - model.add_adapter_to_model('lora', lora_config, gradient_accumulation_steps=16) - - # Set optimizer and scheduler for LoRA adapter (must be after add_adapter_to_model) - model.set_optimizer(AdamW, lr=1e-4, adapter_name='lora') - model.set_lr_scheduler(LinearLR, adapter_name='lora') - - # Print training configuration - 
logger.info(get_device_placement()) - logger.info(model.get_train_configs(adapter_name='lora')) - - # Training loop - gradient_accumulation_steps = 16 - optimizer_step = 0 - max_steps = args.max_steps - - for step, batch in enumerate(dataloader): - output = model.forward_backward(inputs=batch, adapter_name='lora') - - # Only perform optimizer step at gradient accumulation boundary - if (step + 1) % gradient_accumulation_steps == 0: - optimizer_step = (step + 1) // gradient_accumulation_steps - - # Log loss - logger.info(f'Current is step {optimizer_step}, loss: {output}') - - # Gradient clipping and optimizer step - model.clip_grad_norm(1.0, adapter_name='lora') - model.step(adapter_name='lora') - model.zero_grad(adapter_name='lora') - model.lr_step(adapter_name='lora') - - # Save checkpoint every 100 optimizer steps - if optimizer_step % 100 == 0: - model.save('./output/megatron_lora', adapter_name='lora') - - # Check max_steps for early stopping (for testing) - if max_steps is not None and optimizer_step >= max_steps: - logger.info(f'Reached max_steps ({max_steps}), stopping training.') - break - - # Save final checkpoint - logger.info(f'Training completed! 
Final step: {optimizer_step}') - model.save('./output/megatron_lora', adapter_name='lora') + lora_config = LoraConfig(target_modules='all-linear') -def main(): - args = parse_args() - - # Set TWINKLE_MODE environment variable for strategy to detect - os.environ['TWINKLE_MODE'] = args.mode - - # Create device mesh and group - device_mesh = create_device_mesh(args) - device_group = create_device_group(args) - - # Initialize twinkle with specified mode - twinkle.initialize( - mode=args.mode, - nproc_per_node=args.world_size, - groups=device_group, - global_device_mesh=device_mesh, - lazy_collect=False, - ) - - try: - # Start training - train(args) - finally: - # Clean up distributed process groups - cleanup_distributed() + # Use 'lora' as adapter_name and pass it consistently to all methods + adapter_name = 'lora' + model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=16) + model.set_template('Qwen3Template', adapter_name=adapter_name) + model.set_processor(InputProcessor, padding_side='right', adapter_name=adapter_name) + # Note: For MegatronModel, loss is computed internally by Megatron. + # set_loss() is optional and mainly for API compatibility. 
+ model.set_loss(VocabParallelCrossEntropyLoss, adapter_name=adapter_name) + model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) + model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) + logger.info(get_device_placement()) + logger.info(model.get_train_configs(adapter_name=adapter_name)) -def cleanup_distributed(): - """Clean up all distributed process groups.""" - import torch + for step, batch in enumerate(dataloader): + output = model.forward_backward(inputs=batch, adapter_name=adapter_name) + if step % 16 == 0: + logger.info(f'Step {step // 16}, loss: {output}') + model.clip_grad_norm(1.0, adapter_name=adapter_name) + model.step(adapter_name=adapter_name) + model.zero_grad(adapter_name=adapter_name) + model.lr_step(adapter_name=adapter_name) + if step % 100 == 0: + model.save('./output/megatron_lora', adapter_name=adapter_name) + # Early stop for testing + if args.max_steps and step >= args.max_steps * 16: + logger.info(f'Reached max_steps ({args.max_steps}), stopping.') + break + + logger.info('Training completed!') + + +def cleanup(): + """Clean up distributed resources.""" import torch.distributed as dist - - # Synchronize all processes before cleanup - if dist.is_initialized(): - try: - # Use barrier with timeout to prevent hanging - dist.barrier() - except Exception as e: - logger.warning(f'Barrier failed during cleanup: {e}') - - try: - dist.destroy_process_group() - except Exception as e: - logger.warning(f'Failed to destroy process group: {e}') - - # Also clean up Megatron's parallel state if initialized try: + # Barrier to ensure all processes are synchronized before cleanup + if dist.is_initialized(): + dist.barrier() from megatron.core import parallel_state as mpu if mpu.is_initialized(): mpu.destroy_model_parallel() except Exception: pass - - # Clear CUDA cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if dist.is_initialized(): + dist.destroy_process_group() if __name__ == '__main__': - main() + try: + train() + 
finally: + cleanup() diff --git a/cookbook/megatron/lora_ray.py b/cookbook/megatron/lora_ray.py new file mode 100644 index 00000000..85e58b78 --- /dev/null +++ b/cookbook/megatron/lora_ray.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core LoRA training in Ray mode. + +This script uses MegatronWorkerGroup for Ray-based distributed training +with proper Megatron collective operations support. + +NOTE: PP > 1 is REQUIRED for training. PP=1 has known gradient flow issues +with PEFT/LoRA and Megatron's forward_backward_no_pipelining. + +Usage: + # TP=2, PP=2 (4 GPUs) - RECOMMENDED + python cookbook/megatron/lora_ray.py --tp_size 2 --pp_size 2 --num_gpus 4 + + # PP=4, TP=1 (4 GPUs) + python cookbook/megatron/lora_ray.py --tp_size 1 --pp_size 4 --num_gpus 4 + + # PP=2, TP=1 (2 GPUs) + python cookbook/megatron/lora_ray.py --tp_size 1 --pp_size 2 --num_gpus 2 +""" +import argparse +import os +import sys + +# Add paths +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) +megatron_path = os.environ.get('MEGATRON_LM_PATH', '/mnt/nas2/hujinghan.hjh/Megatron-LM') +sys.path.insert(0, megatron_path) + +import ray +import torch +import numpy as np + +from twinkle import get_logger +from twinkle.megatron.worker import MegatronWorkerGroup + +logger = get_logger() + + +def create_dataset(): + """Create and prepare the training dataset - same as local mode.""" + from twinkle.dataset import Dataset, DatasetMeta + + dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + dataset.set_template('Qwen3Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct') + dataset.map('CompetitionMathProcessor') + dataset.encode(batched=True, load_from_cache_file=False) + return dataset + + +def collate_batch(samples, batch_size: int, max_seq_len: int = 512): + """Collate samples into a batch with padding.""" + # Take batch_size samples + samples = samples[:batch_size] + + # Get max length in 
def collate_batch(samples, batch_size, max_seq_len):
    """Collate raw samples into a padded training batch (capped at max_seq_len).

    NOTE(review): the def line of this function lies outside the visible chunk;
    the signature is reconstructed from its call site — confirm against the file.
    `batch_size` is accepted for interface compatibility but not used here.

    Args:
        samples: List of dicts with 'input_ids' (list[int]) and optionally 'labels'.
        batch_size: Unused; kept for caller compatibility.
        max_seq_len: Hard cap on the padded sequence length.

    Returns:
        Dict of long tensors: 'input_ids' (0-padded), 'attention_mask' (1/0),
        'labels' (-100-padded so padding is ignored by the loss).
    """
    max_len = min(max(len(s['input_ids']) for s in samples), max_seq_len)

    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    for s in samples:
        ids = s['input_ids'][:max_len]
        pad_len = max_len - len(ids)

        input_ids_list.append(ids + [0] * pad_len)
        attention_mask_list.append([1] * len(ids) + [0] * pad_len)

        # Labels: use -100 for padding. Pad relative to len(labels), not len(ids),
        # so a 'labels' list shorter than 'input_ids' still yields a max_len row
        # (identical behavior when the lengths match, which is the normal case).
        labels = s.get('labels', ids)[:max_len]
        labels_list.append(labels + [-100] * (max_len - len(labels)))

    return {
        'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long),
        'labels': torch.tensor(labels_list, dtype=torch.long),
    }


def main():
    """Run Megatron LoRA training in Ray mode and report the loss trend.

    Builds a Ray worker group sized for TP*PP*CP, loads the model, attaches
    LoRA adapters, trains on a single fixed batch (overfitting sanity check),
    and validates that the loss is sane and decreasing.

    Returns:
        0 on success (initial loss < 3 and loss decreased), 1 otherwise.
    """
    parser = argparse.ArgumentParser(description='Megatron LoRA training in Ray mode')
    parser.add_argument('--tp_size', type=int, default=2, help='Tensor parallel size')
    parser.add_argument('--pp_size', type=int, default=2, help='Pipeline parallel size (must be > 1 for training)')
    parser.add_argument('--cp_size', type=int, default=1, help='Context parallel size')
    parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs')
    parser.add_argument('--max_steps', type=int, default=10, help='Max training steps')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size per step')
    parser.add_argument('--max_seq_len', type=int, default=512, help='Max sequence length')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--model', type=str, default='ms://Qwen/Qwen2.5-0.5B-Instruct',
                        help='Model path or ID')
    parser.add_argument('--lora_r', type=int, default=8, help='LoRA rank')
    args = parser.parse_args()

    # Validate parallelism config: world size must cover TP x PP x CP.
    expected_gpus = args.tp_size * args.pp_size * args.cp_size
    if args.num_gpus < expected_gpus:
        logger.error(f"Need at least {expected_gpus} GPUs for TP={args.tp_size}, "
                     f"PP={args.pp_size}, CP={args.cp_size}, but only {args.num_gpus} provided")
        return 1

    # Guard: zero steps would make the summary below index an empty list.
    if args.max_steps <= 0:
        logger.error("--max_steps must be >= 1")
        return 1

    # Prepare dataset first (on driver, before Ray workers)
    logger.info("Preparing dataset...")
    dataset = create_dataset()
    samples = [dataset[i] for i in range(min(len(dataset), args.max_steps * args.batch_size + 100))]
    logger.info(f"Loaded {len(samples)} samples")

    # Initialize Ray
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True)

    logger.info(f"Ray initialized with {args.num_gpus} GPUs")
    logger.info(f"Config: TP={args.tp_size}, PP={args.pp_size}, CP={args.cp_size}")

    # Create worker group
    worker_group = MegatronWorkerGroup(
        world_size=args.num_gpus,
        tp_size=args.tp_size,
        pp_size=args.pp_size,
        cp_size=args.cp_size,
    )

    try:
        # Initialize workers
        logger.info("Initializing workers...")
        results = worker_group.init_all()
        if not all(results):
            raise RuntimeError("Worker initialization failed")

        # Create model
        logger.info(f"Loading model: {args.model}")
        results = worker_group.create_model_all(
            pretrained_model_name_or_path=args.model,
            mixed_precision='bf16',
            recompute_granularity='full',
        )
        if not all(results):
            raise RuntimeError("Model creation failed")

        # Add LoRA with Megatron layer names
        logger.info("Adding LoRA adapters...")
        lora_config = {
            'target_modules': ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'],
            'r': args.lora_r,
            'lora_alpha': args.lora_r,
            'lora_dropout': 0.0,
        }
        results = worker_group.add_lora_all(lora_config)
        if not all(results):
            raise RuntimeError("LoRA addition failed")

        # Set optimizer
        logger.info(f"Setting optimizer with lr={args.lr}")
        results = worker_group.set_optimizer_all(lr=args.lr)
        if not all(results):
            raise RuntimeError("Optimizer setup failed")

        # Training loop
        logger.info(f"Starting training for {args.max_steps} steps...")
        losses = []

        # Use same batch for all steps to verify loss decreases (overfitting test)
        fixed_batch = collate_batch(samples[:args.batch_size], args.batch_size, args.max_seq_len)

        for step in range(args.max_steps):
            batch = fixed_batch

            # Forward-backward
            step_losses = worker_group.forward_backward_all(batch)

            # Get valid loss (non-zero from last PP stage); non-final PP stages
            # report 0.0, so filter those out before averaging.
            valid_losses = [loss_val for loss_val in step_losses if loss_val > 0]
            avg_loss = np.mean(valid_losses) if valid_losses else 0.0
            losses.append(avg_loss)

            logger.info(f"Step {step:3d}/{args.max_steps}, loss: {avg_loss:.4f}")

            # Optimizer step
            worker_group.step_all()

        # Check loss trend
        logger.info("=" * 60)
        logger.info("Training Summary:")
        logger.info(f"  Initial loss: {losses[0]:.4f}")
        logger.info(f"  Final loss: {losses[-1]:.4f}")
        logger.info(f"  Loss change: {losses[-1] - losses[0]:.4f}")

        # Validation checks (aligned with local mode expectations)
        initial_ok = losses[0] < 3  # Real data should have initial loss < 3
        decreasing = losses[-1] < losses[0]  # Should decrease over training

        if initial_ok:
            logger.info("✓ Initial loss is reasonable (< 3)")
        else:
            logger.warning(f"✗ Initial loss {losses[0]:.4f} is too high (expected < 3)")

        if decreasing:
            logger.info("✓ Loss is decreasing (training is working)")
        else:
            logger.warning("✗ Loss is not decreasing")

        logger.info("=" * 60)
        logger.info("Training completed!")

        return 0 if (initial_ok and decreasing) else 1

    except Exception as e:
        logger.error(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return 1

    finally:
        # Best-effort teardown: never let cleanup failures mask the real exit code.
        logger.info("Cleaning up...")
        try:
            worker_group.cleanup_all()
        except Exception:
            pass
        try:
            worker_group.shutdown()
        except Exception:
            pass


if __name__ == '__main__':
    sys.exit(main())
-"""Vocabulary-parallel cross entropy loss for Megatron TP training. - -When using Tensor Parallelism, the vocabulary dimension is sharded across TP ranks. -Standard CrossEntropyLoss will fail because labels may fall outside the local -vocab partition. This module provides vocab-parallel loss computation. -""" -from typing import Optional - +"""Vocab-parallel cross entropy loss for Megatron backend with Tensor Parallelism.""" import torch -import torch.nn.functional as F from .base import Loss -try: - from megatron.core import parallel_state as mpu - from megatron.core import tensor_parallel - MEGATRON_AVAILABLE = True -except ImportError: - MEGATRON_AVAILABLE = False - class VocabParallelCrossEntropyLoss(Loss): - """Cross entropy loss that handles vocabulary parallelism in Megatron. + """Vocab-parallel cross entropy loss for Megatron training with TP > 1. - When using TP (Tensor Parallelism), the vocabulary is sharded across TP ranks. - This loss uses Megatron's vocab_parallel_cross_entropy which correctly handles - the distributed computation. + This loss uses Megatron's tensor_parallel.vocab_parallel_cross_entropy to + correctly compute cross entropy when vocabulary is sharded across TP ranks. NOTE: Labels are expected to be pre-shifted by the template (using np.roll). This loss does NOT perform additional shifting. - Fallback: When Megatron is not available or TP=1, uses standard CrossEntropyLoss. + Args: + ignore_index: The label value to ignore when computing loss. Default: -100. 
""" + def __init__(self, ignore_index: int = -100): + super().__init__() + self.ignore_index = ignore_index + def __call__(self, inputs, outputs, **kwargs): + from megatron.core import tensor_parallel + logits = outputs['logits'] labels = inputs['labels'] - shift_logits = logits[:, :-1, :].contiguous() - shift_labels = labels[:, :-1].contiguous() - - if not MEGATRON_AVAILABLE: - # Fallback to standard loss - logits_2d = shift_logits.view(-1, shift_logits.shape[-1]) - labels_1d = shift_labels.view(-1) - return F.cross_entropy(logits_2d, labels_1d, ignore_index=-100) - - tp_size = mpu.get_tensor_model_parallel_world_size() - - if tp_size == 1: - # No TP, use standard cross entropy - logits_2d = shift_logits.view(-1, shift_logits.shape[-1]) - labels_1d = shift_labels.view(-1) - return F.cross_entropy(logits_2d, labels_1d, ignore_index=-100) + # Transpose: [batch, seq, vocab] -> [seq, batch, vocab] + logits_sbv = logits.transpose(0, 1).contiguous() + labels_sb = labels.transpose(0, 1).contiguous() - # Use Megatron's vocab-parallel cross entropy - # Megatron expects [seq, batch, vocab] format for logits - # and [seq, batch] for labels - - # Transpose logits: [batch, seq-1, vocab] -> [seq-1, batch, vocab] - logits_sbv = shift_logits.transpose(0, 1).contiguous() - - # Transpose labels: [batch, seq-1] -> [seq-1, batch] - labels_sb = shift_labels.transpose(0, 1).contiguous() - - # Megatron's vocab_parallel_cross_entropy handles the TP sharding correctly - # It returns per-token loss of shape [seq-1, batch] + # Compute vocab-parallel cross entropy per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb) - - # Transpose back: [seq-1, batch] -> [batch, seq-1] per_token_loss = per_token_loss.transpose(0, 1).contiguous() - # Apply loss mask (ignore labels == -100) - loss_mask = (shift_labels != -100).float() - - # Compute mean loss (only over non-masked positions) + # Apply loss mask + loss_mask = (labels != self.ignore_index).float() loss = 
(per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) return loss -class MegatronCrossEntropyLoss(VocabParallelCrossEntropyLoss): - """Alias for VocabParallelCrossEntropyLoss. - - Use this when training with Megatron backend and TP > 1. - """ - pass +# Alias for backward compatibility +MegatronCrossEntropyLoss = VocabParallelCrossEntropyLoss diff --git a/src/twinkle/megatron/model/__init__.py b/src/twinkle/megatron/model/__init__.py index a330f603..a4e289f2 100644 --- a/src/twinkle/megatron/model/__init__.py +++ b/src/twinkle/megatron/model/__init__.py @@ -7,6 +7,7 @@ from .bridge import ( # Main classes TwinkleBridgeAdapter, + TwinkleBridgeInitializer, TwinkleGPTBridge, BridgeConfig, SafetensorLoader, @@ -28,6 +29,7 @@ __all__ = [ # Bridge classes 'TwinkleBridgeAdapter', + 'TwinkleBridgeInitializer', 'TwinkleGPTBridge', 'BridgeConfig', 'SafetensorLoader', diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index a75dce7b..b9a64f97 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -24,7 +24,7 @@ import torch.nn.functional as F import torch.distributed as dist from tqdm import tqdm - +from twinkle.hub import HubOperation try: from megatron.core import parallel_state as mpu MEGATRON_AVAILABLE = True @@ -51,9 +51,28 @@ def deep_getattr(obj, attr: str, default=None): def is_last_rank() -> bool: - """Check if current process is the last rank.""" + """Check if current process is the last rank for writing. + + For DP > 1, we want only DP rank 0 to write to avoid conflicts. + For PP, we want the last PP stage. + For TP, all TP ranks participate in gather, but only one writes. 
+ """ if not dist.is_initialized(): return True + + try: + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + # Only DP rank 0 writes + dp_rank = mpu.get_data_parallel_rank() + if dp_rank != 0: + return False + # For PP, only last stage needs to write certain weights + # (handled separately in export_weights) + return True + except (ImportError, AssertionError): + pass + return dist.get_rank() == dist.get_world_size() - 1 @@ -1152,15 +1171,26 @@ def save_weights( mg_models: Megatron model(s) to save. output_dir: Directory to save weights. is_peft_format: Whether saving in PEFT format. + + Note: + For DP > 1, only DP rank 0 writes to disk. All ranks participate + in tensor gather operations for TP. """ torch.cuda.empty_cache() - saver = StreamingSafetensorSaver( - save_dir=output_dir, - max_shard_size=self.config.max_shard_size, - is_peft_format=is_peft_format, - ) + # Determine if this rank should write + should_write = is_last_rank() + # Only the writing rank creates the saver + saver = None + if should_write: + saver = StreamingSafetensorSaver( + save_dir=output_dir, + max_shard_size=self.config.max_shard_size, + is_peft_format=is_peft_format, + ) + + # All ranks participate in export (needed for TP gather) for key, tensor in self.export_weights( mg_models, target_device='cpu', @@ -1168,12 +1198,14 @@ def save_weights( is_peft_format=is_peft_format, tqdm_desc='Saving: ', ): - saver.add_tensor(key, tensor) + if saver is not None and tensor is not None: + saver.add_tensor(key, tensor) - saver.finalize() + if saver is not None: + saver.finalize() - # Save config on last rank - if is_last_rank(): + # Save config on writing rank only + if should_write: if is_peft_format and not isinstance(mg_models, (list, tuple)): mg_models = [mg_models] @@ -1188,6 +1220,7 @@ def save_weights( self.hf_config.vocab_size = self.config.padded_vocab_size self.hf_config.save_pretrained(output_dir) + # Synchronize all ranks before continuing if 
class TwinkleBridgeInitializer:
    """
    Megatron model initializer.

    This class provides complete model initialization flow including:
    - Megatron parallel state initialization
    - Model creation from HuggingFace config
    - Weight loading using TwinkleGPTBridge

    Example:
        initializer = TwinkleBridgeInitializer(
            tp_size=2,
            pp_size=1,
            params_dtype=torch.bfloat16,
        )
        model = initializer.create_model('Qwen/Qwen2.5-7B-Instruct')
    """

    def __init__(
        self,
        tp_size: int = 1,
        pp_size: int = 1,
        cp_size: int = 1,
        ep_size: int = 1,
        etp_size: Optional[int] = None,
        params_dtype=None,
        use_cpu_initialization: bool = False,
        attention_backend: str = 'flash',
        recompute_granularity: Optional[str] = 'selective',
        recompute_modules: Optional[list] = None,
        recompute_method: Optional[str] = None,
        recompute_num_layers: Optional[int] = None,
    ):
        """Initialize TwinkleBridgeInitializer.

        Args:
            tp_size: Tensor parallel size.
            pp_size: Pipeline parallel size.
            cp_size: Context parallel size.
            ep_size: Expert parallel size.
            etp_size: Expert tensor parallel size (defaults to tp_size).
            params_dtype: Parameter dtype (default: torch.bfloat16).
            use_cpu_initialization: Initialize on CPU first.
            attention_backend: Attention backend.
            recompute_granularity: Activation recomputation strategy.
                'selective' (default): Only recompute core attention (memory efficient).
                'full': Recompute entire transformer layer (most memory efficient).
                None: No recomputation (fastest but highest memory).
            recompute_modules: Modules to recompute when using 'selective' granularity.
                Default: ['core_attn'] for efficient memory/compute trade-off.
            recompute_method: Method for full recompute ('uniform' or 'block').
                Auto-configured to 'uniform' when recompute_granularity='full'
                and this is left as None (see _create_model_from_config).
            recompute_num_layers: Number of layers to recompute for 'full' mode.
                Auto-configured to num_layers // pp_size when
                recompute_granularity='full' and this is left as None.
        """
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.cp_size = cp_size
        self.ep_size = ep_size
        self.etp_size = etp_size or tp_size
        self.params_dtype = params_dtype if params_dtype is not None else torch.bfloat16
        self.use_cpu_initialization = use_cpu_initialization
        self.attention_backend = attention_backend
        self.recompute_granularity = recompute_granularity
        self.recompute_modules = recompute_modules or ['core_attn']
        self.recompute_method = recompute_method
        self.recompute_num_layers = recompute_num_layers

        # Lazily populated by create_model() / load_weights().
        self._model = None
        self._bridge = None
        self._hf_config = None
        self._model_path = None

    def _download_model(self, model_path: str) -> str:
        """Download model if it's a model ID.

        NOTE(review): create_model() resolves paths via HubOperation.download_model
        instead of this helper — this method looks unused in the visible code;
        confirm before removing.
        """
        if os.path.isdir(model_path):
            return model_path

        try:
            from modelscope import snapshot_download
            return snapshot_download(model_path)
        except ImportError:
            from huggingface_hub import snapshot_download
            return snapshot_download(model_path)

    def _initialize_megatron(self, hf_config: Any = None):
        """Initialize Megatron parallel state.

        This sets up the required process groups for tensor, pipeline,
        and data parallelism using Megatron's parallel state module directly.

        Args:
            hf_config: Optional HuggingFace config for additional model parameters.
                NOTE(review): currently unparameterized — the config is accepted
                but not read by this method.
        """
        import torch.distributed as dist
        from megatron.core import parallel_state as mpu
        from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed

        # Check if already initialized (mpu raises AssertionError when it is not).
        try:
            if mpu.is_initialized():
                return
        except AssertionError:
            pass

        # Initialize distributed if not already
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl')

        # Initialize Megatron parallel state directly
        mpu.initialize_model_parallel(
            tensor_model_parallel_size=self.tp_size,
            pipeline_model_parallel_size=self.pp_size,
            context_parallel_size=self.cp_size,
            expert_model_parallel_size=self.ep_size,
        )

        # Initialize CUDA RNG tracker for tensor parallel random states
        # This is required when use_cpu_initialization=False (GPU initialization)
        model_parallel_cuda_manual_seed(42)

    def _create_model_from_config(
        self,
        hf_config: Any,
        padded_vocab_size: int,
    ) -> nn.Module:
        """Create Megatron GPT model from HuggingFace config.

        Args:
            hf_config: HuggingFace model configuration.
            padded_vocab_size: Padded vocabulary size.

        Returns:
            Megatron GPT model.
        """
        import torch.distributed as dist
        from megatron.core import parallel_state as mpu
        from megatron.core.transformer import TransformerConfig
        from megatron.core.models.gpt import GPTModel
        from megatron.core.models.gpt.gpt_layer_specs import (
            get_gpt_layer_with_transformer_engine_spec,
        )

        # Convert HF config to Megatron config
        from ..utils import convert_hf_config
        mg_config_dict = convert_hf_config(hf_config)

        # Build TransformerConfig
        num_attention_heads = mg_config_dict['num_attention_heads']
        num_query_groups = mg_config_dict.get('num_query_groups', num_attention_heads)
        num_layers = mg_config_dict['num_layers']

        # Configure activation recomputation
        recompute_method = self.recompute_method
        recompute_num_layers = self.recompute_num_layers

        # Auto-configure for 'full' recomputation if not specified
        if self.recompute_granularity == 'full':
            if recompute_method is None:
                recompute_method = 'uniform'
            if recompute_num_layers is None:
                # Recompute all layers for maximum memory savings
                recompute_num_layers = num_layers // self.pp_size

        # Create finalize_model_grads function for DP gradient synchronization
        # Megatron's native finalize_model_grads requires DDP-wrapped models with ddp_config.
        # For PEFT/LoRA models, we use a custom implementation that handles non-DDP models.
        from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads

        def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None):
            """Gradient-finalization hook for both DDP and PEFT/LoRA models.

            For DDP-wrapped models: Delegates to Megatron's native finalize_model_grads
            For PEFT/LoRA models: Calls finish_grad_sync() on each model chunk

            This is necessary because PEFT models don't have ddp_config attribute
            that Megatron's native implementation expects.
            """
            from megatron.core import parallel_state as mpu

            # Check if model is DDP-wrapped (has ddp_config)
            if hasattr(model[0], 'ddp_config'):
                # Use native implementation for DDP models
                return _native_finalize_model_grads(model, num_tokens, pg_collection)

            # For PEFT/LoRA models, call finish_grad_sync on each chunk
            # The model should have finish_grad_sync added by MegatronModel.add_adapter_to_model
            for model_chunk in model:
                if hasattr(model_chunk, 'finish_grad_sync'):
                    model_chunk.finish_grad_sync()

        config = TransformerConfig(
            num_layers=num_layers,
            hidden_size=mg_config_dict['hidden_size'],
            num_attention_heads=num_attention_heads,
            num_query_groups=num_query_groups,
            ffn_hidden_size=mg_config_dict.get('ffn_hidden_size', 4 * mg_config_dict['hidden_size']),
            tensor_model_parallel_size=self.tp_size,
            pipeline_model_parallel_size=self.pp_size,
            context_parallel_size=self.cp_size,
            expert_model_parallel_size=self.ep_size,
            params_dtype=self.params_dtype,
            pipeline_dtype=self.params_dtype,  # Required when using pipeline parallelism
            use_cpu_initialization=self.use_cpu_initialization,
            add_qkv_bias=mg_config_dict.get('add_qkv_bias', False),
            add_bias_linear=not mg_config_dict.get('disable_bias_linear', True),
            gated_linear_unit=mg_config_dict.get('swiglu', True),
            normalization='RMSNorm',
            layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6),
            qk_layernorm=mg_config_dict.get('qk_layernorm', False),
            hidden_dropout=0.0,
            attention_dropout=0.0,
            # Activation recomputation for memory efficiency
            recompute_granularity=self.recompute_granularity,
            recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None,
            recompute_method=recompute_method,
            recompute_num_layers=recompute_num_layers,
            # Critical: Set finalize_model_grads_func for DP gradient synchronization
            # Uses custom wrapper that handles both DDP and PEFT/LoRA models
            finalize_model_grads_func=finalize_model_grads_for_lora,
        )

        # Get layer spec; fall back to the local (non-TransformerEngine) spec
        # when the TE spec cannot be built.
        try:
            layer_spec = get_gpt_layer_with_transformer_engine_spec(
                num_experts=mg_config_dict.get('num_experts'),
                moe_grouped_gemm=False,
                qk_layernorm=mg_config_dict.get('qk_layernorm', False),
            )
        except Exception:
            from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
            layer_spec = get_gpt_layer_local_spec(
                num_experts=mg_config_dict.get('num_experts'),
                moe_grouped_gemm=False,
                qk_layernorm=mg_config_dict.get('qk_layernorm', False),
            )

        # Create model
        max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096)
        rotary_base = mg_config_dict.get('rotary_base', 10000)

        model = GPTModel(
            config=config,
            transformer_layer_spec=layer_spec,
            vocab_size=padded_vocab_size,
            max_sequence_length=max_seq_length,
            pre_process=mpu.is_pipeline_first_stage(),
            post_process=mpu.is_pipeline_last_stage(),
            parallel_output=True,
            share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False),
            position_embedding_type='rope',
            rotary_base=rotary_base,
        )

        return model

    def _pad_vocab_size(self, vocab_size: int) -> int:
        """Pad vocab size up to a multiple of tp_size * 128 for tensor parallelism."""
        divisor = self.tp_size * 128
        return ((vocab_size + divisor - 1) // divisor) * divisor

    def create_model(
        self,
        model_path: str,
        load_weights: bool = True,
    ) -> nn.Module:
        """Create Megatron model from HuggingFace checkpoint.

        Args:
            model_path: Path to HuggingFace model or model ID.
            load_weights: Whether to load weights.

        Returns:
            Megatron model.
        """
        from transformers import AutoConfig

        # Download model if needed
        model_path = HubOperation.download_model(model_path)
        self._model_path = model_path

        # Load HF config first (needed for initialization)
        self._hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        # Initialize Megatron parallel state with hf_config for proper args setup
        self._initialize_megatron(self._hf_config)

        # Calculate padded vocab size
        padded_vocab_size = self._pad_vocab_size(self._hf_config.vocab_size)

        # Create model
        self._model = self._create_model_from_config(self._hf_config, padded_vocab_size)

        # Load weights
        if load_weights:
            bridge_adapter = TwinkleBridgeAdapter(
                hf_config=self._hf_config,
                tp_size=self.tp_size,
                pp_size=self.pp_size,
                ep_size=self.ep_size,
                etp_size=self.etp_size,
                model_path=model_path,
                padded_vocab_size=padded_vocab_size,
            )
            bridge_adapter.load_weights(self._model, model_path)
            self._bridge = bridge_adapter._get_bridge()

        # Synchronize all ranks after model creation and weight loading
        # This is critical for Pipeline Parallel to ensure all ranks are ready
        # before any collective communication operations
        if dist.is_initialized():
            dist.barrier()

        return self._model

    @property
    def hf_config(self):
        """Get the HuggingFace config."""
        return self._hf_config

    @property
    def bridge(self):
        """Get the bridge instance."""
        return self._bridge

    def load_weights(self, model: nn.Module, model_path: str):
        """Load weights into an existing model.

        Args:
            model: Megatron model.
            model_path: Path to HuggingFace checkpoint.

        Raises:
            ValueError: If create_model was never called (no hf_config/bridge).
        """
        if self._bridge is None and self._hf_config is None:
            raise ValueError("Must call create_model first")

        padded_vocab_size = self._pad_vocab_size(self._hf_config.vocab_size)
        # NOTE(review): unlike create_model(), this adapter is built without
        # etp_size — confirm whether expert-TP models need it here too.
        bridge_adapter = TwinkleBridgeAdapter(
            hf_config=self._hf_config,
            tp_size=self.tp_size,
            pp_size=self.pp_size,
            ep_size=self.ep_size,
            model_path=model_path,
            padded_vocab_size=padded_vocab_size,
        )
        bridge_adapter.load_weights(model, model_path)

    def save_weights(self, models: Union[nn.Module, List[nn.Module]], output_dir: str, is_peft_format: bool = False):
        """Save weights in HuggingFace format.

        Args:
            models: Megatron model(s).
            output_dir: Output directory.
            is_peft_format: Whether to save in PEFT format.

        Raises:
            ValueError: If no bridge is available (weights were never loaded).
        """
        if self._bridge is None:
            raise ValueError("Must load weights first")

        if not isinstance(models, (list, tuple)):
            models = [models]

        self._bridge.save_weights(models, output_dir, is_peft_format=is_peft_format)
@@ -261,29 +259,20 @@ def load_from_hf( # Resolve model path if it's a model ID (not a local path) if not os.path.isdir(hf_model_path): - # Try to download using HubOperation (twinkle's hub abstraction) - try: - from twinkle.hub import HubOperation - hf_model_path = HubOperation.download_model(hf_model_path) - except ImportError: - # Fallback to modelscope/huggingface snapshot download - try: - from modelscope import snapshot_download - hf_model_path = snapshot_download(hf_model_path) - except ImportError: - from huggingface_hub import snapshot_download - hf_model_path = snapshot_download(hf_model_path) - + from twinkle.hub import HubOperation + hf_model_path = HubOperation.download_model(hf_model_path) + # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(hf_config.vocab_size) - from .swift_bridge import create_bridge_adapter - adapter = create_bridge_adapter( + # Use TwinkleBridgeAdapter + from .bridge import TwinkleBridgeAdapter + adapter = TwinkleBridgeAdapter( hf_config=hf_config, tp_size=self.tp_size, pp_size=self.pp_size, ep_size=self.ep_size, - model_dir=hf_model_path, + model_path=hf_model_path, padded_vocab_size=padded_vocab_size, ) adapter.load_weights(model, hf_model_path) diff --git a/src/twinkle/megatron/model/swift_bridge.py b/src/twinkle/megatron/model/swift_bridge.py deleted file mode 100644 index 1a1ad4b9..00000000 --- a/src/twinkle/megatron/model/swift_bridge.py +++ /dev/null @@ -1,673 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -"""Bridge module for Megatron-Core weight conversion. - -TODO: Remove dependency on swift package. The bridge logic should be -implemented independently in twinkle to avoid external dependencies. - -This module provides: -1. TwinkleArgs: A dataclass that mimics megatron.training.get_args() return value -2. 
MegatronBridgeInitializer: Creates Megatron models with proper initialization -""" -import os -from contextlib import contextmanager -from dataclasses import dataclass -from functools import partial -from typing import Any, Optional - -import torch.distributed as dist - -try: - from safetensors.torch import safe_open - SAFETENSORS_AVAILABLE = True -except ImportError: - SAFETENSORS_AVAILABLE = False - - -# Cache for Swift bridge availability check -_SWIFT_BRIDGE_AVAILABLE = None -_SWIFT_GPT_BRIDGE_CLASS = None - - -def deep_getattr(obj, attr: str, default=None): - """Get nested attribute from object using dot notation.""" - try: - for key in attr.split('.'): - obj = getattr(obj, key) - return obj - except AttributeError: - return default - - -def is_last_rank() -> bool: - """Check if current process is the last rank.""" - if not dist.is_initialized(): - return True - return dist.get_rank() == dist.get_world_size() - 1 - - -class LazyTensor: - """Lazy tensor wrapper for deferred loading.""" - def __init__(self, tensor=None, loader=None): - self.tensor = tensor - self.loader = loader - - def load(self): - if self.tensor is None: - return self.loader() - return self.tensor - - -class SafetensorLazyLoader: - """Lazy loader for safetensor files.""" - def __init__(self, hf_model_dir: str, is_peft_format: bool = False): - self.hf_model_dir = hf_model_dir - self.is_peft_format = is_peft_format - self._weight_map = {} - self._file_handles = {} - self._load_index() - - def _open_file(self, filename: str): - if filename not in self._file_handles: - file_path = os.path.join(self.hf_model_dir, filename) - self._file_handles[filename] = safe_open(file_path, framework='pt') - return self._file_handles[filename] - - def _load_index(self): - import json - index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') - if os.path.exists(index_path): - with open(index_path, 'r') as f: - self._weight_map = json.load(f).get('weight_map', {}) - else: - safetensors_fname = 
'adapter_model.safetensors' if self.is_peft_format else 'model.safetensors' - safetensors_file = os.path.join(self.hf_model_dir, safetensors_fname) - if os.path.exists(safetensors_file): - with safe_open(safetensors_file, framework='pt') as f: - for key in f.keys(): - self._weight_map[key] = safetensors_fname - - def get_state_dict(self): - return {k: LazyTensor(loader=partial(self._load_tensor, key=k)) for k in self._weight_map.keys()} - - def _load_tensor(self, key): - return self._open_file(self._weight_map[key]).get_tensor(key) - - def close(self): - self._file_handles.clear() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - -@dataclass -class TwinkleArgs: - """Args class that mimics megatron.training.get_args() return value. - - TODO: Remove swift dependency. This class is currently designed to be compatible - with external GPTBridge. Once independent bridge logic is implemented, this - can be simplified. - """ - # Model architecture - hidden_size: int = 4096 - num_attention_heads: int = 32 - num_query_groups: int = 32 - num_layers: int = 32 - ffn_hidden_size: int = 11008 - padded_vocab_size: int = 32000 - - # Model options - group_query_attention: bool = False - add_qkv_bias: bool = False - add_bias_linear: bool = False - qk_layernorm: bool = False - multi_latent_attention: bool = False - untie_embeddings_and_output_weights: bool = True - - # MoE - num_experts: Optional[int] = None - moe_shared_expert_intermediate_size: Optional[int] = None - moe_router_enable_expert_bias: bool = False - - # MLA (Multi-Latent Attention) - for DeepSeek models - q_lora_rank: Optional[int] = None - kv_lora_rank: int = 32 - - # MTP (Multi-Token Prediction) - mtp_num_layers: int = 0 - - # Parallelism - tensor_model_parallel_size: int = 1 - pipeline_model_parallel_size: int = 1 - expert_model_parallel_size: int = 1 - expert_tensor_parallel_size: int = 1 - context_parallel_size: int = 1 - sequence_parallel: bool = False - - 
distributed_timeout_minutes: int = 300000 - distributed_backend: str = 'nccl' - local_rank: int = 0 - rank: int = 0 - world_size: int = 1 - - # Paths and identifiers - model_dir: str = '' - hf_model_type: str = 'qwen2' - - # Task type - task_type: str = 'causal_lm' - - # Save settings - max_shard_size: str = '5GB' - - # Multimodal - is_multimodal: bool = False - - # Hub settings - use_hf: bool = False - hub_token: Optional[str] = None - - # Additional Megatron settings - fp16: bool = False - bf16: bool = True - accumulate_allreduce_grads_in_fp32: bool = False - async_tensor_model_parallel_allreduce: bool = False - use_distributed_optimizer: bool = False - overlap_grad_reduce: bool = False - overlap_param_gather: bool = False - - # Softmax type - softmax_type: str = 'vanilla' - - # Extra Megatron arguments - padding_free: bool = True - mlp_padding_free: bool = False - check_model: bool = True - initialize_embedding: bool = False - rope_scaling: Optional[Any] = None - torch_dtype: Optional[Any] = None - model: Optional[str] = None - model_type: Optional[str] = None - load_safetensors: Optional[bool] = None - save_safetensors: bool = True - adapters: Optional[Any] = None - merge_lora: Optional[bool] = None - - # Training settings - micro_batch_size: int = 1 - global_batch_size: int = 16 - recompute_granularity: str = 'selective' - recompute_method: Optional[str] = None - recompute_num_layers: Optional[int] = None - use_cpu_initialization: bool = False - deterministic_mode: bool = False - no_masked_softmax_fusion: bool = False - no_bias_dropout_fusion: Optional[bool] = None - no_bias_swiglu_fusion: bool = False - no_rope_fusion: Optional[bool] = None - - # LoRA settings - train_type: Optional[str] = None - lora_rank: int = 8 - lora_alpha: int = 8 - - @classmethod - def from_hf_config(cls, hf_config: Any, tp_size: int = 1, pp_size: int = 1, - ep_size: int = 1, etp_size: Optional[int] = None, - model_dir: str = '', padded_vocab_size: Optional[int] = None, - use_hf: bool 
= False, hub_token: Optional[str] = None): - """Create TwinkleArgs from HuggingFace config. - - Args: - hf_config: HuggingFace model configuration. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - ep_size: Expert parallel size. - etp_size: Expert tensor parallel size (defaults to tp_size). - model_dir: Path to model directory. - padded_vocab_size: Padded vocabulary size (auto-computed if None). - use_hf: Whether to use HuggingFace Hub (vs ModelScope). - hub_token: Hub token for authentication. - """ - import os - - vocab_size = getattr(hf_config, 'vocab_size', 32000) - if padded_vocab_size is None: - # Pad to multiple of tp_size * 128 for efficiency - divisor = tp_size * 128 - padded_vocab_size = ((vocab_size + divisor - 1) // divisor) * divisor - - num_attention_heads = getattr(hf_config, 'num_attention_heads', 32) - num_query_groups = getattr(hf_config, 'num_key_value_heads', num_attention_heads) - model_type = getattr(hf_config, 'model_type', 'qwen2') - - # Determine QKV bias - Qwen2 has bias by default but config doesn't expose it - if hasattr(hf_config, 'attention_bias'): - add_qkv_bias = hf_config.attention_bias - elif model_type in ('qwen2', 'qwen2_5'): - add_qkv_bias = True - else: - add_qkv_bias = False - - # MoE config - num_experts = getattr(hf_config, 'num_experts', None) or \ - getattr(hf_config, 'n_routed_experts', None) or \ - getattr(hf_config, 'num_local_experts', None) - - # QK layernorm (Qwen3) - qk_layernorm = getattr(hf_config, 'qk_layernorm', False) or \ - getattr(hf_config, 'use_qk_norm', False) - - # MLA settings (DeepSeek) - q_lora_rank = getattr(hf_config, 'q_lora_rank', None) - multi_latent_attention = q_lora_rank is not None or \ - getattr(hf_config, 'kv_lora_rank', None) is not None - - # Get distributed settings from environment - local_rank = int(os.environ.get('LOCAL_RANK', 0)) - rank = int(os.environ.get('RANK', 0)) - world_size = int(os.environ.get('WORLD_SIZE', 1)) - - return cls( - 
hidden_size=getattr(hf_config, 'hidden_size', 4096), - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - num_layers=getattr(hf_config, 'num_hidden_layers', 32), - ffn_hidden_size=getattr(hf_config, 'intermediate_size', 11008), - padded_vocab_size=padded_vocab_size, - group_query_attention=num_query_groups != num_attention_heads, - add_qkv_bias=add_qkv_bias, - add_bias_linear=getattr(hf_config, 'mlp_bias', False), - qk_layernorm=qk_layernorm, - multi_latent_attention=multi_latent_attention, - untie_embeddings_and_output_weights=not getattr(hf_config, 'tie_word_embeddings', False), - num_experts=num_experts, - moe_shared_expert_intermediate_size=getattr(hf_config, 'shared_expert_intermediate_size', None), - q_lora_rank=q_lora_rank, - kv_lora_rank=getattr(hf_config, 'kv_lora_rank', 32), - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size or tp_size, - local_rank=local_rank, - rank=rank, - world_size=world_size, - model_dir=model_dir, - hf_model_type=model_type, - use_hf=use_hf, - hub_token=hub_token, - adapters=[], # Initialize as empty list - ) - - -# ============================================================================= -# GPTBridge Adapter -# TODO: Implement independent bridge logic to remove swift dependency. -# ============================================================================= -def _import_swift_bridge(): - """Import GPTBridge from external package. - - TODO: Implement independent bridge logic in twinkle. The weight conversion - between HuggingFace and Megatron formats should be self-contained. - - Returns: - GPTBridge class if available, None otherwise. 
- """ - global _SWIFT_BRIDGE_AVAILABLE, _SWIFT_GPT_BRIDGE_CLASS - - if _SWIFT_BRIDGE_AVAILABLE is not None: - return _SWIFT_GPT_BRIDGE_CLASS - - try: - from swift.utils import disable_safe_ddp_context_use_barrier - - with disable_safe_ddp_context_use_barrier(): - from swift.megatron.model.gpt_bridge import GPTBridge - - _SWIFT_BRIDGE_AVAILABLE = True - _SWIFT_GPT_BRIDGE_CLASS = GPTBridge - return GPTBridge - except ImportError as e: - _SWIFT_BRIDGE_AVAILABLE = False - _SWIFT_GPT_BRIDGE_CLASS = None - return None - except Exception as e: - import traceback - print(f"Warning: Failed to import GPTBridge: {e}") - traceback.print_exc() - _SWIFT_BRIDGE_AVAILABLE = False - _SWIFT_GPT_BRIDGE_CLASS = None - return None - - -def use_swift_bridge() -> bool: - """Check if GPTBridge is available.""" - _import_swift_bridge() - return _SWIFT_BRIDGE_AVAILABLE is True - - -class SwiftBridgeAdapter: - """Adapter to use swift's GPTBridge with twinkle's TwinkleArgs. - - TODO: Remove swift dependency. Implement independent bridge logic in twinkle. - - This class wraps swift's GPTBridge for weight loading/saving between - HuggingFace and Megatron formats. - """ - - def __init__(self, args: TwinkleArgs, hf_model=None, disable_tqdm: bool = False): - self.args = args - self.hf_model = hf_model - self.disable_tqdm = disable_tqdm - self._swift_bridge = None - - self._init_swift_bridge() - - def _init_swift_bridge(self): - """Initialize swift's GPTBridge with our args.""" - GPTBridge = _import_swift_bridge() - if GPTBridge is None: - raise ImportError( - "swift package is required for Megatron weight loading. 
" - "Please install: pip install ms-swift" - ) - - # Use Megatron's official set_args to set global args - from megatron.training.global_vars import set_args, get_args - - # Check if args already set - try: - existing_args = get_args() - # Args already initialized, we'll use existing - self._swift_bridge = GPTBridge(disable_tqmd=self.disable_tqdm) - except AssertionError: - # Args not initialized, set our args - set_args(self.args) - self._swift_bridge = GPTBridge(disable_tqmd=self.disable_tqdm) - - def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False): - """Load weights from HuggingFace checkpoint into Megatron model.""" - self._swift_bridge.load_weights(mg_model, hf_model_dir, is_peft_format) - - def save_weights(self, mg_models, output_dir: str, hf_model_dir: str = None, is_peft_format: bool = False): - """Save weights in HuggingFace format.""" - self._swift_bridge.save_weights(mg_models, output_dir, is_peft_format) - - -def create_bridge_adapter( - hf_config: Any, - tp_size: int = 1, - pp_size: int = 1, - ep_size: int = 1, - model_dir: str = '', - padded_vocab_size: Optional[int] = None, -) -> SwiftBridgeAdapter: - """Create a bridge adapter for weight loading/saving. - - TODO: Remove swift dependency. Implement independent bridge logic. - - Args: - hf_config: HuggingFace model config. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - ep_size: Expert parallel size. - model_dir: Path to model directory. - padded_vocab_size: Padded vocabulary size. - - Returns: - SwiftBridgeAdapter instance. 
- """ - args = TwinkleArgs.from_hf_config( - hf_config, - tp_size=tp_size, - pp_size=pp_size, - ep_size=ep_size, - model_dir=model_dir, - padded_vocab_size=padded_vocab_size, - ) - - return SwiftBridgeAdapter(args) - - -def create_megatron_model_with_swift( - model_path: str, - tp_size: int = 1, - pp_size: int = 1, - ep_size: int = 1, - params_dtype=None, - use_cpu_initialization: bool = True, - attention_backend: str = 'unfused', - load_weights: bool = True, -): - """Create Megatron model using swift's initialization flow. - - TODO: Remove swift dependency. Implement independent initialization logic. - - Args: - model_path: Path to HuggingFace model or model ID. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - ep_size: Expert parallel size. - params_dtype: Parameter dtype (default: torch.bfloat16). - use_cpu_initialization: Initialize on CPU first (for memory efficiency). - attention_backend: Attention backend ('unfused' for precision, 'flash' for speed). - load_weights: Whether to load weights. - - Returns: - Tuple of (model, bridge, megatron_model_meta). 
- """ - import torch - from transformers import AutoConfig - - if params_dtype is None: - params_dtype = torch.bfloat16 - - # Download model if needed - if not os.path.isdir(model_path): - try: - from modelscope import snapshot_download - model_path = snapshot_download(model_path) - except ImportError: - from huggingface_hub import snapshot_download - model_path = snapshot_download(model_path) - - # Load HF config - hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - - # Import Swift modules with barrier disabled - from swift.utils import disable_safe_ddp_context_use_barrier - - with disable_safe_ddp_context_use_barrier(): - from swift.megatron import ( - MegatronArguments, convert_hf_config, get_megatron_model_meta - ) - - from megatron.training.initialize import initialize_megatron - from megatron.training import get_args - - # Check if Megatron is already initialized - try: - existing_args = get_args() - megatron_initialized = True - except AssertionError: - megatron_initialized = False - - # Get model meta first to get extra_args_provider - megatron_model_meta = get_megatron_model_meta(hf_config.model_type) - if megatron_model_meta is None: - raise ValueError(f'Model type {hf_config.model_type} not supported by Swift') - - if not megatron_initialized: - # Convert HF config to Megatron config kwargs - config_kwargs = convert_hf_config(hf_config) - - # Create MegatronArguments - megatron_args = MegatronArguments( - model=model_path, - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - expert_model_parallel_size=ep_size, - torch_dtype=params_dtype, - use_cpu_initialization=use_cpu_initialization, - attention_backend=attention_backend, - **config_kwargs, - ) - - # Parse to Megatron format - extra_args = megatron_args.parse_to_megatron() - - # Initialize Megatron - extra_args_provider = megatron_model_meta.extra_args_provider - initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) - 
- # Determine pre_process and post_process based on pipeline stage - from megatron.core import parallel_state as mpu - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - - model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) - - # Load weights if requested - bridge = None - if load_weights: - bridge = megatron_model_meta.bridge_cls() - bridge.load_weights(model, model_path) - - return model, bridge, megatron_model_meta - - -class MegatronBridgeInitializer: - """Megatron model initializer using bridge-based initialization flow. - - TODO: Remove swift dependency. Implement independent initialization logic. - - Example: - initializer = MegatronBridgeInitializer( - tp_size=2, - pp_size=1, - params_dtype=torch.bfloat16, - ) - model = initializer.create_model('Qwen/Qwen2.5-7B-Instruct') - """ - - def __init__( - self, - tp_size: int = 1, - pp_size: int = 1, - ep_size: int = 1, - params_dtype=None, - use_cpu_initialization: bool = True, - attention_backend: str = 'flash', # Use flash for training performance - ): - """Initialize MegatronBridgeInitializer. - - Args: - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - ep_size: Expert parallel size. - params_dtype: Parameter dtype (default: torch.bfloat16). - use_cpu_initialization: Initialize on CPU first. - attention_backend: Attention backend. - """ - import torch - - self.tp_size = tp_size - self.pp_size = pp_size - self.ep_size = ep_size - self.params_dtype = params_dtype if params_dtype is not None else torch.bfloat16 - self.use_cpu_initialization = use_cpu_initialization - self.attention_backend = attention_backend - - self._model = None - self._bridge = None - self._model_meta = None - self._hf_config = None - - def create_model( - self, - model_path: str, - load_weights: bool = True, - ): - """Create Megatron model from HuggingFace checkpoint. - - Args: - model_path: Path to HuggingFace model or model ID. 
- load_weights: Whether to load weights. - - Returns: - Megatron model. - """ - from transformers import AutoConfig - - # Download model if needed - if not os.path.isdir(model_path): - try: - from modelscope import snapshot_download - model_path = snapshot_download(model_path) - except ImportError: - from huggingface_hub import snapshot_download - model_path = snapshot_download(model_path) - - # Store HF config - self._hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - - self._model, self._bridge, self._model_meta = create_megatron_model_with_swift( - model_path=model_path, - tp_size=self.tp_size, - pp_size=self.pp_size, - ep_size=self.ep_size, - params_dtype=self.params_dtype, - use_cpu_initialization=self.use_cpu_initialization, - attention_backend=self.attention_backend, - load_weights=load_weights, - ) - - return self._model - - @property - def hf_config(self): - """Get the HuggingFace config.""" - return self._hf_config - - @property - def bridge(self): - """Get the Swift bridge instance.""" - return self._bridge - - @property - def model_meta(self): - """Get the Megatron model meta.""" - return self._model_meta - - def load_weights(self, model, model_path: str): - """Load weights into an existing model. - - Args: - model: Megatron model. - model_path: Path to HuggingFace checkpoint. - """ - if self._bridge is None: - # Create bridge from model meta - if self._model_meta is None: - raise ValueError("Must call create_model first or provide model_meta") - self._bridge = self._model_meta.bridge_cls() - - self._bridge.load_weights(model, model_path) - - def save_weights(self, models, output_dir: str, is_peft_format: bool = False): - """Save weights in HuggingFace format. - - Args: - models: Megatron model(s). - output_dir: Output directory. - is_peft_format: Whether to save in PEFT format. 
- """ - if self._bridge is None: - raise ValueError("Must load weights first") - - if not isinstance(models, (list, tuple)): - models = [models] - - self._bridge.save_weights(models, output_dir, is_peft_format=is_peft_format) diff --git a/src/twinkle/megatron/worker.py b/src/twinkle/megatron/worker.py new file mode 100644 index 00000000..1afe7818 --- /dev/null +++ b/src/twinkle/megatron/worker.py @@ -0,0 +1,368 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron Worker for Ray-based distributed training. + +This module provides MegatronWorkerGroup for coordinated Ray actor-based +training with Megatron's collective operations. + +NOTE: Currently PP > 1 is required for Ray mode training with LoRA. +PP=1 has gradient flow issues that need further investigation. + +Example: + worker_group = MegatronWorkerGroup(world_size=4, tp_size=2, pp_size=2) + worker_group.init_all() + worker_group.create_model_all('Qwen/Qwen2.5-0.5B-Instruct') + worker_group.add_lora_all({'target_modules': ['linear_qkv'], 'r': 8}) + worker_group.set_optimizer_all(lr=1e-4) + + for batch in dataloader: + losses = worker_group.forward_backward_all(batch) + worker_group.step_all() +""" +import os +from typing import Any, Dict, List + +import torch + + +def get_megatron_worker_class(): + """Returns a Ray remote class for Megatron workers.""" + import ray + + @ray.remote(num_gpus=1) + class MegatronWorker: + """Ray actor for a single Megatron rank.""" + + def __init__( + self, + rank: int, + world_size: int, + master_addr: str, + master_port: int, + tp_size: int = 1, + pp_size: int = 1, + cp_size: int = 1, + ep_size: int = 1, + ): + self.rank = rank + self.world_size = world_size + self.master_addr = master_addr + self.master_port = master_port + self.tp_size = tp_size + self.pp_size = pp_size + self.cp_size = cp_size + self.ep_size = ep_size + self.model = None + self.optimizer = None + self.hf_config = None + + def _get_local_gpu_id(self) -> int: + """Get local GPU ID for this 
actor.""" + import ray + cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cvd is None: + gpu_ids = ray.get_gpu_ids() + return int(gpu_ids[0]) if gpu_ids else 0 + else: + gpu_ids = ray.get_gpu_ids() + if gpu_ids: + return cvd.split(",").index(str(int(gpu_ids[0]))) + return 0 + + def init(self, model_config: Dict[str, Any] = None) -> bool: + """Initialize distributed and Megatron parallel state.""" + import torch.distributed as dist + from datetime import timedelta + + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["RANK"] = str(self.rank) + + local_rank = self._get_local_gpu_id() + os.environ["LOCAL_RANK"] = str(local_rank) + torch.cuda.set_device(local_rank) + + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", timeout=timedelta(minutes=10)) + + from megatron.core import parallel_state as mpu + if not mpu.is_initialized(): + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + context_parallel_size=self.cp_size, + expert_model_parallel_size=self.ep_size, + ) + + from megatron.core import tensor_parallel + torch.manual_seed(42 + self.rank) + tensor_parallel.model_parallel_cuda_manual_seed(42 + self.rank) + + print(f"[Worker rank={self.rank}] Initialized TP={self.tp_size} PP={self.pp_size}") + return True + + def create_model( + self, + pretrained_model_name_or_path: str, + mixed_precision: str = 'bf16', + recompute_granularity: str = 'full', + **kwargs, + ) -> bool: + """Create Megatron model.""" + from twinkle.megatron.model.bridge import TwinkleBridgeInitializer + + dtype_map = {'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16} + params_dtype = dtype_map.get(mixed_precision, torch.bfloat16) + + initializer = TwinkleBridgeInitializer( + tp_size=self.tp_size, + pp_size=self.pp_size, + cp_size=self.cp_size, + ep_size=self.ep_size, + 
params_dtype=params_dtype, + recompute_granularity=recompute_granularity, + **kwargs, + ) + + self.model = initializer.create_model(pretrained_model_name_or_path) + self.hf_config = initializer._hf_config + print(f"[Worker rank={self.rank}] Model created") + return True + + def add_lora(self, lora_config: Dict[str, Any]) -> bool: + """Add LoRA adapter.""" + from peft import get_peft_model, LoraConfig + from peft.tuners.tuners_utils import BaseTuner + import torch.nn as nn + + # Patch for Megatron's TransformerConfig + orig_fn = BaseTuner._get_tied_target_modules + def patched_fn(self, model: nn.Module): + try: + return orig_fn(self, model) + except AttributeError: + return [] + BaseTuner._get_tied_target_modules = patched_fn + + from twinkle.megatron.utils import set_linear_is_expert + set_linear_is_expert(self.model) + + config = LoraConfig(**lora_config) + self.model = get_peft_model(self.model, config) + + # Add compatibility methods for Megatron DDP + if not hasattr(self.model, 'finish_grad_sync'): + self.model.finish_grad_sync = lambda: None + if not hasattr(self.model, 'start_grad_sync'): + self.model.start_grad_sync = lambda: None + if not hasattr(self.model, 'no_sync'): + from contextlib import nullcontext + self.model.no_sync = nullcontext + + # Create a dummy ddp_config that has necessary attributes + if not hasattr(self.model, 'ddp_config') or self.model.ddp_config is None: + class DummyDDPConfig: + use_megatron_fsdp = False + use_distributed_optimizer = False + overlap_grad_reduce = False + overlap_param_gather = False + bucket_size = None + self.model.ddp_config = DummyDDPConfig() + + trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + print(f"[Worker rank={self.rank}] LoRA added, trainable params={trainable}") + return True + + def set_optimizer(self, lr: float = 1e-4, **kwargs) -> bool: + """Set up optimizer.""" + from torch.optim import AdamW + trainable_params = [p for p in self.model.parameters() if p.requires_grad] + 
self.optimizer = AdamW(trainable_params, lr=lr, **kwargs) + print(f"[Worker rank={self.rank}] Optimizer set") + return True + + def forward_backward(self, batch: Dict[str, torch.Tensor]) -> float: + """Execute forward-backward pass.""" + from functools import partial + from megatron.core.pipeline_parallel import get_forward_backward_func + + local_rank = self._get_local_gpu_id() + batch = {k: v.cuda(local_rank) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + + seq_length = batch['input_ids'].shape[1] + micro_batch_size = batch['input_ids'].shape[0] + + def forward_step_func(data_iterator, model): + batch = next(data_iterator) + input_ids = batch['input_ids'] + labels = batch.get('labels') + attention_mask = batch.get('attention_mask') + + position_ids = torch.arange( + input_ids.shape[1], device=input_ids.device, dtype=torch.long + ).unsqueeze(0).expand(input_ids.shape[0], -1) + + output = model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + ) + + def loss_func(labels, output): + mask = (labels != -100).float() + loss = (output.float().view(-1) * mask.view(-1)).sum() / mask.sum().clamp(min=1) + return loss, {'loss': loss.detach()} + + return output, partial(loss_func, labels) + + self.model.train() + forward_backward_func = get_forward_backward_func() + + losses = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=iter([batch]), + model=[self.model], + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + ) + + if losses and isinstance(losses[0], dict) and 'loss' in losses[0]: + return losses[0]['loss'].item() + return 0.0 + + def step(self) -> bool: + """Optimizer step.""" + if self.optimizer is None: + return False + torch.nn.utils.clip_grad_norm_( + [p for p in self.model.parameters() if p.requires_grad], 1.0 + ) + self.optimizer.step() + self.optimizer.zero_grad() + return True + + def cleanup(self) -> 
bool: + """Clean up resources.""" + import torch.distributed as dist + from megatron.core import parallel_state as mpu + try: + if dist.is_initialized(): + dist.barrier() + if mpu.is_initialized(): + mpu.destroy_model_parallel() + if dist.is_initialized(): + dist.destroy_process_group() + except Exception as e: + print(f"[Worker rank={self.rank}] Cleanup error: {e}") + return True + + return MegatronWorker + + +class MegatronWorkerGroup: + """Manager for coordinated Megatron Ray workers. + + Handles synchronized creation, initialization, and execution + of Megatron workers for distributed training. + + NOTE: PP > 1 is required for training with LoRA. PP=1 has gradient issues. + """ + + def __init__( + self, + world_size: int, + tp_size: int = 1, + pp_size: int = 1, + cp_size: int = 1, + ep_size: int = 1, + master_addr: str = None, + master_port: int = None, + ): + import ray + import socket + + # Warn if PP=1 (known gradient issue) + if pp_size == 1: + print("[MegatronWorkerGroup] WARNING: PP=1 has known gradient issues. " + "Training loss may not decrease. 
Use PP > 1 for training.") + + self.world_size = world_size + self.tp_size = tp_size + self.pp_size = pp_size + self.cp_size = cp_size + self.ep_size = ep_size + + if master_addr is None: + master_addr = ray.util.get_node_ip_address() + if master_port is None: + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + + self.master_addr = master_addr + self.master_port = master_port + + MegatronWorker = get_megatron_worker_class() + self.workers = [ + MegatronWorker.remote( + rank=rank, + world_size=world_size, + master_addr=master_addr, + master_port=master_port, + tp_size=tp_size, + pp_size=pp_size, + cp_size=cp_size, + ep_size=ep_size, + ) + for rank in range(world_size) + ] + print(f"[MegatronWorkerGroup] Created {world_size} workers (TP={tp_size}, PP={pp_size})") + + def init_all(self, model_config: Dict[str, Any] = None) -> List[bool]: + """Initialize all workers.""" + import ray + return ray.get([w.init.remote(model_config) for w in self.workers]) + + def create_model_all(self, pretrained_model_name_or_path: str, **kwargs) -> List[bool]: + """Create model on all workers.""" + import ray + return ray.get([w.create_model.remote(pretrained_model_name_or_path, **kwargs) for w in self.workers]) + + def add_lora_all(self, lora_config: Dict[str, Any]) -> List[bool]: + """Add LoRA to all workers.""" + import ray + return ray.get([w.add_lora.remote(lora_config) for w in self.workers]) + + def set_optimizer_all(self, lr: float = 1e-4, **kwargs) -> List[bool]: + """Set optimizer on all workers.""" + import ray + return ray.get([w.set_optimizer.remote(lr, **kwargs) for w in self.workers]) + + def forward_backward_all(self, batch: Dict[str, torch.Tensor]) -> List[float]: + """Execute forward/backward on all workers.""" + import ray + return ray.get([w.forward_backward.remote(batch) for w in self.workers]) + + def step_all(self) -> List[bool]: + """Optimizer step on all workers.""" + import ray + return ray.get([w.step.remote() for w in 
self.workers]) + + def cleanup_all(self) -> List[bool]: + """Cleanup all workers.""" + import ray + return ray.get([w.cleanup.remote() for w in self.workers]) + + def shutdown(self): + """Shutdown all workers.""" + import ray + for worker in self.workers: + try: + ray.kill(worker) + except Exception: + pass + self.workers = [] diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 77c3f691..462cd925 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -111,6 +111,8 @@ def __init__( use_distributed_optimizer: bool = True, load_weights: bool = True, use_megatron_bridge: bool = True, # Use bridge-based initialization (recommended) + recompute_granularity: Optional[str] = 'selective', # Activation checkpointing + recompute_modules: Optional[list] = None, # Modules to recompute **kwargs, ): check_megatron_available() @@ -120,6 +122,8 @@ def __init__( self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.use_megatron_bridge = use_megatron_bridge + self.recompute_granularity = recompute_granularity + self.recompute_modules = recompute_modules # Load HuggingFace config first model_path = HubOperation.download_model(pretrained_model_name_or_path) @@ -197,11 +201,12 @@ def _create_megatron_model_with_bridge( ) -> nn.Module: """Create Megatron model using bridge-based initialization flow. - This approach uses the bridge initialization which includes: + This approach uses TwinkleBridgeInitializer for independent initialization + It includes: - Proper config conversion from HuggingFace to Megatron format - Correct Megatron initialization (initialize_megatron) - - Correct model creation (model_provider) - - All necessary patches (RoPE, TransformerLayer, etc.) + - Correct model creation + - Weight loading with TwinkleGPTBridge Args: model_path: Path to HuggingFace model. @@ -212,16 +217,21 @@ def _create_megatron_model_with_bridge( Returns: Megatron model on GPU. 
""" - from twinkle.megatron.model.swift_bridge import MegatronBridgeInitializer + from twinkle.megatron.model.bridge import TwinkleBridgeInitializer # Create bridge-based initializer - self._bridge_initializer = MegatronBridgeInitializer( + self._bridge_initializer = TwinkleBridgeInitializer( tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, + cp_size=self.strategy.cp_size, ep_size=self.strategy.ep_size, params_dtype=params_dtype, - use_cpu_initialization=True, + use_cpu_initialization=False, attention_backend='flash', # Use flash for training performance + recompute_granularity=self.recompute_granularity, + recompute_modules=self.recompute_modules, + recompute_method=getattr(self, 'recompute_method', None), + recompute_num_layers=getattr(self, 'recompute_num_layers', None), ) # Create model (this calls initialize_megatron internally) @@ -279,13 +289,25 @@ def _create_megatron_model_native( return model def _move_model_to_gpu(self, model: nn.Module) -> nn.Module: + """Move model to correct GPU device. + + This method handles moving parameters, buffers, and any cached tensors + (like RoPE embeddings) to the correct device for distributed training. 
+ """ # Determine the target device based on local rank local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0 device = torch.device(f'cuda:{local_rank}') + # Set CUDA device explicitly + torch.cuda.set_device(local_rank) + # Move all parameters and buffers to GPU model = model.to(device) + # Force synchronize to ensure all transfers complete + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + return model def _lazy_wrap_model(self): @@ -353,7 +375,12 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, if processor is not None: inputs: Dict[str, Any] = processor(inputs) - labels = inputs.pop('labels', None) + labels = inputs.get('labels', None) + if 'labels' in inputs: + try: + del inputs['labels'] + except (TypeError, KeyError): + pass # Some dict-like types don't support deletion # Forward through model outputs = self._forward_step(inputs) @@ -485,7 +512,7 @@ def backward(self, **kwargs): loss_value.backward() optimizer_config.cur_step += 1 - @remote_function(collect='avg') + @remote_function(dispatch='all', collect='avg') def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Combined forward and backward pass using Megatron's scheduler. 
@@ -522,31 +549,70 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr inputs = processor(inputs) # Store labels before removing from inputs - labels = inputs.pop('labels', None) + labels = inputs.get('labels', None) + if 'labels' in inputs: + try: + del inputs['labels'] + except (TypeError, KeyError): + pass # Some dict-like types don't support deletion + + # Get CP size for sequence padding and splitting + cp_size = self.strategy.cp_size + cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0 # Get sequence length and batch size - seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1 + # Note: Megatron's schedule internally divides seq_length by cp_size + # So we pass the padded full sequence length here + original_seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1 micro_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 + # For CP > 1, pad seq_length to be divisible by 2*cp_size + if cp_size > 1: + divisor = 2 * cp_size + if original_seq_length % divisor != 0: + seq_length = original_seq_length + (divisor - original_seq_length % divisor) + else: + seq_length = original_seq_length + else: + seq_length = original_seq_length + # Move labels to GPU if needed if labels is not None and not isinstance(labels, torch.Tensor): labels = torch.tensor(labels, device=torch.cuda.current_device()) elif labels is not None: labels = labels.to(torch.cuda.current_device()) - # Define loss function that matches Megatron's expected signature - # loss_func(output_tensor) -> (loss, {str: tensor}) - def loss_func(labels_tensor, loss_instance, output_tensor): - if labels_tensor is None or loss_instance is None: - loss = torch.tensor(0.0, device=output_tensor.device, requires_grad=True) - return loss, {'loss': loss} + def split_tensor_for_cp(tensor, dim=-1): + """Split tensor along sequence dimension for Context Parallel. 
- inputs_dict = {'labels': labels_tensor} - outputs_dict = {'logits': output_tensor} - loss = loss_instance(inputs_dict, outputs_dict) + With causal masking, split into 2*CP chunks and assign alternating + chunks to balance workload across CP ranks. + For CP rank i: chunks [i, 2*CP-1-i] - # Megatron expects (loss, {str: tensor}) for logging - return loss, {'loss': loss.detach()} + Based on Swift's split_cp_inputs implementation. + """ + if tensor is None or cp_size <= 1: + return tensor + + if dim < 0: + dim = (dim + tensor.ndim) % tensor.ndim + + seq_len = tensor.shape[dim] + + # Reshape to [batch, 2*cp_size, seq_per_chunk, ...] + view_shape = list(tensor.shape) + view_shape[dim:dim+1] = [2 * cp_size, seq_len // (2 * cp_size)] + reshaped = tensor.view(*view_shape) + + # Select chunks [cp_rank, 2*cp_size-1-cp_rank] + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], + device='cpu', pin_memory=True).cuda(non_blocking=True) + selected = reshaped.index_select(dim, index) + + # Reshape back: [batch, 2*seq_per_chunk, ...] 
+ out_shape = list(tensor.shape) + out_shape[dim] = seq_len // cp_size + return selected.reshape(*out_shape) # Define forward step function for Megatron # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) @@ -555,6 +621,26 @@ def forward_step_func(data_iterator, model): input_ids = batch.get('input_ids') position_ids = batch.get('position_ids') attention_mask = batch.get('attention_mask') + batch_labels = batch.get('labels', labels) # Use batch labels or passed labels + + # Pad sequence for Context Parallel compatibility + # Megatron's RoPE requires seq_len % (2 * cp_size) == 0 + if cp_size > 1 and input_ids is not None: + seq_len = input_ids.shape[1] + divisor = 2 * cp_size + if seq_len % divisor != 0: + pad_len = divisor - (seq_len % divisor) + # Pad input_ids + input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=0) + # Pad labels if present + if batch_labels is not None: + batch_labels = torch.nn.functional.pad(batch_labels, (0, pad_len), value=-100) + # Pad attention_mask if present + if attention_mask is not None: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_len), value=0) + # Pad position_ids if present + if position_ids is not None: + position_ids = torch.nn.functional.pad(position_ids, (0, pad_len), value=0) # Create position_ids if not provided if position_ids is None and input_ids is not None: @@ -564,15 +650,55 @@ def forward_step_func(data_iterator, model): dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - # Forward pass + # Split tensors for Context Parallel + # Each CP rank processes a portion of the sequence + if cp_size > 1: + input_ids = split_tensor_for_cp(input_ids, dim=-1) + position_ids = split_tensor_for_cp(position_ids, dim=-1) + attention_mask = split_tensor_for_cp(attention_mask, dim=-1) + batch_labels = split_tensor_for_cp(batch_labels, dim=-1) + + # Forward pass with labels - Megatron will compute loss internally + # This uses Megatron's 
compute_language_model_loss which properly handles + # vocab parallel cross entropy output_tensor = model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, + labels=batch_labels, # Pass labels to let Megatron compute loss ) - # Return output and partial loss function - return output_tensor, partial(loss_func, labels, optimizer_config.loss_instance) + # Megatron's compute_language_model_loss returns per-token loss [batch, seq] + # We need to aggregate it with loss_mask + def megatron_loss_func(labels_for_mask, cp_size, output_tensor): + # output_tensor is per-token loss [batch, seq] + # Create loss mask from labels (ignore -100) + loss_mask = (labels_for_mask != -100).float() + + # Flatten and compute mean + losses = output_tensor.float().view(-1) + loss_mask_flat = loss_mask.view(-1) + + # Compute local sum and count + local_loss_sum = torch.sum(losses * loss_mask_flat) + local_count = loss_mask_flat.sum() + + # For CP > 1, aggregate loss across CP ranks + if cp_size > 1: + # Combine loss_sum and count for efficient all-reduce + loss_data = torch.cat([local_loss_sum.view(1), local_count.view(1)]) + torch.distributed.all_reduce( + loss_data, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_context_parallel_group() + ) + loss = loss_data[0] / loss_data[1].clamp(min=1) + else: + loss = local_loss_sum / local_count.clamp(min=1) + + return loss, {'loss': loss.detach()} + + return output_tensor, partial(megatron_loss_func, batch_labels, cp_size) # Get Megatron's forward-backward function # This automatically selects the right scheduler based on PP config: @@ -607,7 +733,7 @@ def forward_step_func(data_iterator, model): break # For PP > 1, broadcast loss from last PP stage to all ranks - from megatron.core import parallel_state as mpu + # Note: mpu is imported at module level, no need to reimport if mpu.get_pipeline_model_parallel_world_size() > 1: if isinstance(loss, torch.Tensor): loss_tensor = loss.detach().clone() @@ -625,6 +751,15 @@ 
def forward_step_func(data_iterator, model): optimizer_config.cur_step += 1 + # Critical: Synchronize all DP replicas before returning + # This ensures all DP replicas complete the same training step before + # moving to the next batch, preventing P2P communication deadlocks + dp_world_size = mpu.get_data_parallel_world_size() + if dp_world_size > 1: + # Use barrier on DP+CP group to synchronize all replicas + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + dist.barrier(group=dp_cp_group) + if isinstance(loss, torch.Tensor): return loss.detach().cpu().float().numpy() return float(loss) @@ -746,8 +881,16 @@ def lr_step(self, **kwargs): def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): """Set loss function. + NOTE: For MegatronModel, the loss is computed internally by Megatron's + GPTModel when labels are passed. This method is kept for API compatibility + but the provided loss_cls is NOT used during forward_backward. + + Megatron internally uses vocab_parallel_cross_entropy which correctly + handles tensor parallelism. This design ensures Loss classes don't need + to be aware of the training backend (Megatron vs Transformers). + Args: - loss_cls: Loss class or string name. + loss_cls: Loss class or string name (not used for Megatron). **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) @@ -758,6 +901,7 @@ def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): loss_cls = getattr(twinkle.loss, loss_cls) else: loss_cls = Plugin.load_plugin(loss_cls, Loss) + # Keep for API compatibility, but not used in forward_backward optimizer_config.loss_instance = loss_cls() @remote_function() @@ -840,42 +984,66 @@ def save(self, output_dir: str, **kwargs): self._save_tokenizer(output_dir, adapter_name) def _save_hf_format(self, output_dir: str, adapter_name: str): - """Save in HuggingFace format using bridge adapter.""" + """Save in HuggingFace format using bridge adapter. 
+ + For distributed training: + - All PP ranks participate in export (each has different layers) + - Only DP rank 0 actually writes to disk + - Uses barrier for synchronization + + For LoRA training: + - Saves in PEFT format (adapter_model.safetensors + adapter_config.json) + """ from twinkle.megatron.model.bridge import TwinkleBridgeAdapter import os - # Only save from last PP rank and first DP rank to avoid conflicts - if not self.strategy.is_pipeline_last_stage(): - return - - # Only let DP rank 0 save to avoid file conflicts - if hasattr(self.strategy, 'dp_rank') and self.strategy.dp_rank != 0: - return + # Check if this is LoRA training (has adapter_name other than default) + is_lora = adapter_name and adapter_name != '' + is_peft_format = is_lora - # Also check via parallel_state if available + # Create output directory on rank 0 only try: from megatron.core import parallel_state as mpu - if mpu.is_initialized() and mpu.get_data_parallel_rank() != 0: - return + dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 except (ImportError, AssertionError): - pass - - os.makedirs(output_dir, exist_ok=True) + dp_rank = 0 + + if dp_rank == 0: + os.makedirs(output_dir, exist_ok=True) + + # Synchronize before saving + if dist.is_initialized(): + dist.barrier() + + # Calculate padded vocab size + padded_vocab_size = self._pad_vocab_size(self.hf_config.vocab_size) \ + if hasattr(self, '_pad_vocab_size') else None # Use TwinkleBridgeAdapter for weight conversion + # All ranks participate - bridge handles which ranks write adapter = TwinkleBridgeAdapter( hf_config=self.hf_config, tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, ep_size=self.strategy.ep_size, - model_path=self.model_id, + model_path=self._model_path if hasattr(self, '_model_path') else self.model_id, + padded_vocab_size=padded_vocab_size, ) - # Use bridge to save weights in HuggingFace format - adapter.save_weights([self.model], output_dir, is_peft_format=False) - - # Save 
config - self.hf_config.save_pretrained(output_dir) + # Get the model (unwrap if DDP wrapped) + model = self.strategy.unwrap_model(self.model) + + # Use bridge to save weights + adapter.save_weights([model], output_dir, is_peft_format=is_peft_format) + + # Save config on rank 0 only + if dp_rank == 0: + self.hf_config.save_pretrained(output_dir) + + def _pad_vocab_size(self, vocab_size: int) -> int: + """Pad vocab size for tensor parallelism.""" + divisor = self.strategy.tp_size * 128 + return ((vocab_size + divisor - 1) // divisor) * divisor def _save_megatron_format(self, output_dir: str, adapter_name: str): """Save in Megatron checkpoint format.""" @@ -1012,6 +1180,33 @@ def add_adapter_to_model( self.model.module = model else: self.model = model + + # Add finish_grad_sync method for Megatron's finalize_model_grads compatibility + # This is needed because Megatron's forward_backward_func calls finish_grad_sync + # on model chunks, but PEFT models don't have this method by default + if not hasattr(self.model, 'finish_grad_sync'): + def finish_grad_sync(): + """Synchronize gradients across DP ranks for non-DDP models. + + This is a compatibility shim for Megatron's finalize_model_grads. + For PEFT/LoRA models, we manually all-reduce gradients. 
+ """ + dp_world_size = mpu.get_data_parallel_world_size() + if dp_world_size > 1: + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + grads = [] + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + grads.append(param.grad.data) + + if grads: + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=dp_cp_group) + for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + grad.copy_(synced) + + self.model.finish_grad_sync = finish_grad_sync # Create optimizer group for adapter self.optimizer_group[adapter_name] = MegatronOptimizerGroup() diff --git a/tests/test_parallelism.py b/tests/test_parallelism.py new file mode 100644 index 00000000..1a394e0e --- /dev/null +++ b/tests/test_parallelism.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +# Copyright (c) twinkle authors. All rights reserved. +"""Test different parallelism strategies for Megatron backend in local mode. + +This script tests various parallelism configurations: +- TP (Tensor Parallel) +- PP (Pipeline Parallel) +- DP (Data Parallel) +- CP (Context Parallel) +- SP (Sequence Parallel, enabled when TP > 1) +- Combined configurations + +Uses Qwen2.5-0.5B-Instruct for faster testing. 
+ +Usage: + python cookbook/megatron/test_parallelism.py +""" +import os +import subprocess +import sys +import time +from dataclasses import dataclass +from typing import List, Optional + +# Color codes for terminal output +GREEN = '\033[92m' +RED = '\033[91m' +YELLOW = '\033[93m' +RESET = '\033[0m' + + +@dataclass +class TestConfig: + """Test configuration for a parallelism strategy.""" + name: str + tp_size: int + pp_size: int + cp_size: int = 1 + sp_enabled: bool = False # Sequence Parallel + num_gpus: int = 0 # 0 means auto-calculate + max_steps: int = 2 + expected_to_pass: bool = True + notes: str = "" + + def __post_init__(self): + if self.num_gpus == 0: + self.num_gpus = self.tp_size * self.pp_size * self.cp_size + # Ensure at least 1 for DP + if self.num_gpus == 0: + self.num_gpus = 1 + + +# Test configurations +TEST_CONFIGS: List[TestConfig] = [ + # Basic single-GPU + TestConfig( + name="Single GPU (TP=1, PP=1)", + tp_size=1, pp_size=1, cp_size=1, + num_gpus=1, + notes="Baseline test" + ), + + # Tensor Parallel only + TestConfig( + name="TP=2 (Tensor Parallel)", + tp_size=2, pp_size=1, cp_size=1, + notes="Tests tensor sharding" + ), + TestConfig( + name="TP=4 (Tensor Parallel)", + tp_size=4, pp_size=1, cp_size=1, + notes="Larger TP" + ), + + # Pipeline Parallel only + TestConfig( + name="PP=2 (Pipeline Parallel)", + tp_size=1, pp_size=2, cp_size=1, + notes="Tests pipeline stages" + ), + TestConfig( + name="PP=4 (Pipeline Parallel)", + tp_size=1, pp_size=4, cp_size=1, + notes="More pipeline stages" + ), + + # TP + PP combinations + TestConfig( + name="TP=2, PP=2", + tp_size=2, pp_size=2, cp_size=1, + notes="Combined TP+PP" + ), + TestConfig( + name="TP=2, PP=4", + tp_size=2, pp_size=4, cp_size=1, + num_gpus=8, + notes="8-GPU TP+PP" + ), + + # Data Parallel (DP > 1) + TestConfig( + name="TP=2, PP=2, DP=2 (8 GPUs)", + tp_size=2, pp_size=2, cp_size=1, + num_gpus=8, + expected_to_pass=False, + notes="Known issue: P2P deadlock with DP > 1" + ), + + # Context 
Parallel + TestConfig( + name="CP=2 (Context Parallel)", + tp_size=1, pp_size=1, cp_size=2, + expected_to_pass=False, + notes="Known issue: CP communication deadlock" + ), + TestConfig( + name="TP=2, PP=2, CP=2 (8 GPUs)", + tp_size=2, pp_size=2, cp_size=2, + num_gpus=8, + expected_to_pass=False, + notes="Known issue: CP + PP deadlock" + ), + + # Sequence Parallel (with TP) + TestConfig( + name="TP=2 + SP (Sequence Parallel)", + tp_size=2, pp_size=1, cp_size=1, + sp_enabled=True, + notes="SP is typically enabled with TP" + ), + TestConfig( + name="TP=2, PP=2 + SP", + tp_size=2, pp_size=2, cp_size=1, + sp_enabled=True, + notes="Combined TP+PP+SP" + ), +] + + +def get_available_gpus() -> int: + """Get number of available GPUs.""" + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + return len(result.stdout.strip().split('\n')) + except Exception: + pass + return 0 + + +def create_test_script() -> str: + """Create a minimal test script for parallelism testing.""" + script = ''' +# Minimal Megatron parallelism test script +import os +import sys +import argparse + +# Set CUDA device before any imports +import torch +LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) +torch.cuda.set_device(LOCAL_RANK) + +import numpy as np +import twinkle +from twinkle import DeviceMesh, DeviceGroup, Platform, get_logger +from twinkle.model import MegatronModel +from peft import LoraConfig +from torch.optim import AdamW + +logger = get_logger() + +parser = argparse.ArgumentParser() +parser.add_argument('--tp_size', type=int, default=1) +parser.add_argument('--pp_size', type=int, default=1) +parser.add_argument('--cp_size', type=int, default=1) +parser.add_argument('--sp_enabled', action='store_true') +parser.add_argument('--max_steps', type=int, default=2) +args = parser.parse_args() + +WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) +TP_SIZE = args.tp_size 
+PP_SIZE = args.pp_size +CP_SIZE = args.cp_size +DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) + +device_mesh = DeviceMesh( + device_type='cuda', + mesh=np.arange(WORLD_SIZE).reshape(DP_SIZE, CP_SIZE, PP_SIZE, TP_SIZE), + mesh_dim_names=('dp', 'cp', 'pp', 'tp'), +) + +device_group = [ + DeviceGroup(name='model', ranks=list(range(WORLD_SIZE)), + device_type=Platform.get_platform().device_prefix()) +] + +twinkle.initialize( + mode='local', + nproc_per_node=WORLD_SIZE, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, +) + +# Create model with smaller Qwen2.5-0.5B +model = MegatronModel( + pretrained_model_name_or_path='ms://Qwen/Qwen2.5-0.5B-Instruct', + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + sequence_parallel=args.sp_enabled, + mixed_precision='bf16', + recompute_granularity='full', +) + +# Add LoRA +lora_config = LoraConfig(target_modules='all-linear', r=4) +model.add_adapter_to_model('lora', lora_config, gradient_accumulation_steps=1) +model.set_optimizer(AdamW, lr=1e-4, adapter_name='lora') + +logger.info(f"Model initialized: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, DP={DP_SIZE}, SP={args.sp_enabled}") + +# Training loop with dummy data +for step in range(args.max_steps): + batch = { + 'input_ids': torch.randint(0, 1000, (1, 64), device=f'cuda:{LOCAL_RANK}'), + 'attention_mask': torch.ones(1, 64, device=f'cuda:{LOCAL_RANK}'), + 'labels': torch.randint(0, 1000, (1, 64), device=f'cuda:{LOCAL_RANK}'), + } + loss = model.forward_backward(inputs=batch, adapter_name='lora') + logger.info(f"Step {step}, loss: {loss}") + model.step(adapter_name='lora') + model.zero_grad(adapter_name='lora') + +logger.info("Training completed successfully!") + +# Cleanup +import torch.distributed as dist +if dist.is_initialized(): + dist.barrier() + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + mpu.destroy_model_parallel() + 
dist.destroy_process_group() +''' + return script + + +def run_test(config: TestConfig, available_gpus: int, test_script_path: str) -> dict: + """Run a single test configuration.""" + result = { + 'name': config.name, + 'config': f"TP={config.tp_size}, PP={config.pp_size}, CP={config.cp_size}" + + (", SP=True" if config.sp_enabled else ""), + 'gpus': config.num_gpus, + 'status': 'SKIPPED', + 'message': '', + 'duration': 0, + } + + # Check if we have enough GPUs + if config.num_gpus > available_gpus: + result['status'] = 'SKIPPED' + result['message'] = f"Need {config.num_gpus} GPUs, only {available_gpus} available" + return result + + # Build command + cuda_devices = ','.join(str(i) for i in range(config.num_gpus)) + cmd = [ + sys.executable, '-m', 'torch.distributed.run', + '--nproc_per_node', str(config.num_gpus), + test_script_path, + '--tp_size', str(config.tp_size), + '--pp_size', str(config.pp_size), + '--cp_size', str(config.cp_size), + '--max_steps', str(config.max_steps), + ] + if config.sp_enabled: + cmd.append('--sp_enabled') + + env = os.environ.copy() + env['CUDA_VISIBLE_DEVICES'] = cuda_devices + env['MEGATRON_LM_PATH'] = os.environ.get('MEGATRON_LM_PATH', '/mnt/nas2/hujinghan.hjh/Megatron-LM') + env['PYTHONPATH'] = f"{env['MEGATRON_LM_PATH']}:{os.getcwd()}/src:{env.get('PYTHONPATH', '')}" + + # Timeout: 3 minutes per test + timeout = 180 + + start_time = time.time() + try: + proc = subprocess.run( + cmd, env=env, capture_output=True, text=True, timeout=timeout + ) + duration = time.time() - start_time + result['duration'] = duration + + if proc.returncode == 0: + # Check if training completed + if 'Training completed successfully!' in proc.stdout or 'Training completed successfully!' 
in proc.stderr: + result['status'] = 'PASSED' + result['message'] = f"Completed in {duration:.1f}s" + else: + result['status'] = 'FAILED' + result['message'] = "Training did not complete" + else: + result['status'] = 'FAILED' + # Extract error message + stderr = proc.stderr[-500:] if len(proc.stderr) > 500 else proc.stderr + result['message'] = f"Exit code {proc.returncode}: {stderr}" + + except subprocess.TimeoutExpired: + result['status'] = 'TIMEOUT' + result['message'] = f"Exceeded {timeout}s timeout (likely deadlock)" + result['duration'] = timeout + # Kill any remaining processes + subprocess.run(['pkill', '-f', test_script_path], capture_output=True) + time.sleep(2) + except Exception as e: + result['status'] = 'ERROR' + result['message'] = str(e) + + return result + + +def main(): + print("=" * 80) + print("Megatron Parallelism Test Suite") + print("=" * 80) + + available_gpus = get_available_gpus() + print(f"Available GPUs: {available_gpus}") + + if available_gpus == 0: + print(f"{RED}No GPUs available!{RESET}") + return 1 + + # Create test script + test_script_path = '/tmp/megatron_parallelism_test.py' + with open(test_script_path, 'w') as f: + f.write(create_test_script()) + print(f"Test script created: {test_script_path}") + + # Run tests + results = [] + for i, config in enumerate(TEST_CONFIGS): + print(f"\n[{i+1}/{len(TEST_CONFIGS)}] Testing: {config.name}") + print(f" Config: TP={config.tp_size}, PP={config.pp_size}, CP={config.cp_size}, GPUs={config.num_gpus}") + if config.notes: + print(f" Notes: {config.notes}") + + result = run_test(config, available_gpus, test_script_path) + results.append(result) + + # Print result + if result['status'] == 'PASSED': + status_str = f"{GREEN}PASSED{RESET}" + elif result['status'] == 'SKIPPED': + status_str = f"{YELLOW}SKIPPED{RESET}" + else: + status_str = f"{RED}{result['status']}{RESET}" + + print(f" Result: {status_str}") + if result['message']: + print(f" Message: {result['message'][:200]}") + + # Check if 
result matches expectation + if config.expected_to_pass and result['status'] not in ['PASSED', 'SKIPPED']: + print(f" {RED}UNEXPECTED FAILURE (was expected to pass){RESET}") + elif not config.expected_to_pass and result['status'] == 'PASSED': + print(f" {GREEN}UNEXPECTED SUCCESS (was expected to fail){RESET}") + + # Summary + print("\n" + "=" * 80) + print("Test Summary") + print("=" * 80) + + passed = sum(1 for r in results if r['status'] == 'PASSED') + failed = sum(1 for r in results if r['status'] in ['FAILED', 'TIMEOUT', 'ERROR']) + skipped = sum(1 for r in results if r['status'] == 'SKIPPED') + + print(f"{GREEN}PASSED: {passed}{RESET}") + print(f"{RED}FAILED/TIMEOUT: {failed}{RESET}") + print(f"{YELLOW}SKIPPED: {skipped}{RESET}") + + print("\nDetailed Results:") + print("-" * 80) + for r in results: + status = r['status'] + if status == 'PASSED': + status_color = GREEN + elif status == 'SKIPPED': + status_color = YELLOW + else: + status_color = RED + print(f" {r['name']}: {status_color}{status}{RESET}") + + return 0 if failed == 0 else 1 + + +if __name__ == '__main__': + sys.exit(main()) From 51b436aa83c97ac0ea7b1c3aec6359e074e13eb1 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 Jan 2026 20:34:03 +0800 Subject: [PATCH 04/22] clean ray --- .locks/dataset.lock | 0 ...lscope_competition_math_default_train.lock | 0 cookbook/megatron/lora.py | 159 ++++--- docs/megatron_architecture.md | 183 ++++++++ docs/megatron_ray_status.md | 213 +++++++++ src/twinkle/data_format/input_feature.py | 8 +- src/twinkle/megatron/model/bridge.py | 32 +- src/twinkle/model/megatron.py | 34 +- src/twinkle/model/strategy/megatron.py | 50 ++- src/twinkle/template/base.py | 4 +- src/twinkle/utils/framework.py | 3 - src/twinkle/utils/parallel.py | 31 +- tests/test_parallelism.py | 405 ------------------ 13 files changed, 627 insertions(+), 495 deletions(-) create mode 100644 .locks/dataset.lock create mode 100644 .locks/ms___modelscope_competition_math_default_train.lock create mode 
100644 docs/megatron_architecture.md create mode 100644 docs/megatron_ray_status.md delete mode 100644 tests/test_parallelism.py diff --git a/.locks/dataset.lock b/.locks/dataset.lock new file mode 100644 index 00000000..e69de29b diff --git a/.locks/ms___modelscope_competition_math_default_train.lock b/.locks/ms___modelscope_competition_math_default_train.lock new file mode 100644 index 00000000..e69de29b diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index 04a9fcdd..f5816cc6 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -1,22 +1,37 @@ # Copyright (c) twinkle authors. All rights reserved. """Megatron-Core LoRA training example. -Usage (8 GPUs with TP2 PP2 CP2): - torchrun --nproc_per_node=8 cookbook/megatron/lora.py --tp_size 2 --pp_size 2 --cp_size 2 +Supports both local (torchrun) and Ray execution modes. -Usage (4 GPUs with TP2 PP2): +Usage (Local mode): torchrun --nproc_per_node=4 cookbook/megatron/lora.py --tp_size 2 --pp_size 2 -Usage (single GPU): - torchrun --nproc_per_node=1 cookbook/megatron/lora.py --tp_size 1 --pp_size 1 +Usage (Ray mode): + TRUST_REMOTE_CODE=1 python cookbook/megatron/lora.py --mode ray --tp_size 2 --pp_size 2 --num_gpus 4 """ import argparse import os +import sys -# CRITICAL: Set CUDA device before any CUDA imports to ensure correct device placement +# Parse arguments first to determine mode +parser = argparse.ArgumentParser() +parser.add_argument('--mode', type=str, default='local', choices=['local', 'ray']) +parser.add_argument('--tp_size', type=int, default=1) +parser.add_argument('--pp_size', type=int, default=1) +parser.add_argument('--cp_size', type=int, default=1) +parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs (Ray mode only)') +parser.add_argument('--max_steps', type=int, default=None) +parser.add_argument('--model', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct') +args = parser.parse_args() + +# Set mode in environment before importing twinkle 
+os.environ['TWINKLE_MODE'] = args.mode + +# CRITICAL: Set CUDA device before any CUDA imports (local mode only) import torch -LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) -torch.cuda.set_device(LOCAL_RANK) +if args.mode == 'local': + LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) + torch.cuda.set_device(LOCAL_RANK) import numpy as np from peft import LoraConfig @@ -33,46 +48,6 @@ logger = get_logger() -# Parse arguments -parser = argparse.ArgumentParser() -parser.add_argument('--tp_size', type=int, default=1) -parser.add_argument('--pp_size', type=int, default=1) -parser.add_argument('--cp_size', type=int, default=1) -parser.add_argument('--max_steps', type=int, default=None) -args = parser.parse_args() - -# Get parallelism config -WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) -TP_SIZE = args.tp_size -PP_SIZE = args.pp_size -CP_SIZE = args.cp_size -DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) - -# Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost -# For mesh shape, we reverse the order: (pp, dp, cp, tp) where rightmost is innermost -# This ensures DP groups match between twinkle and Megatron -device_mesh = DeviceMesh( - device_type='cuda', - mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), - mesh_dim_names=('pp', 'dp', 'cp', 'tp'), -) - -device_group = [ - DeviceGroup( - name='model', - ranks=list(range(WORLD_SIZE)), - device_type=Platform.get_platform().device_prefix(), - ) -] - -twinkle.initialize( - mode='local', - nproc_per_node=WORLD_SIZE, - groups=device_group, - global_device_mesh=device_mesh, - lazy_collect=False, -) - def create_dataset(): dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) @@ -83,29 +58,86 @@ def create_dataset(): def train(): + # Get parallelism config + TP_SIZE = args.tp_size + PP_SIZE = args.pp_size + CP_SIZE = args.cp_size + + if args.mode == 'local': + WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) + else: + WORLD_SIZE = 
args.num_gpus + + DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) + + # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost + device_mesh = DeviceMesh( + device_type='cuda', + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), + mesh_dim_names=('pp', 'dp', 'cp', 'tp'), + ) + + # Device group name - used as remote_group in Ray mode + GROUP_NAME = 'model' + + device_group = [ + DeviceGroup( + name=GROUP_NAME, + ranks=list(range(WORLD_SIZE)), + device_type=Platform.get_platform().device_prefix(), + ) + ] + + twinkle.initialize( + mode=args.mode, + nproc_per_node=WORLD_SIZE, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, + ) + # Use smaller batch size for single GPU to avoid OOM batch_size = 2 if WORLD_SIZE == 1 else 8 - dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) - - model = MegatronModel( - pretrained_model_name_or_path='ms://Qwen/Qwen2.5-7B-Instruct', - tensor_model_parallel_size=TP_SIZE, - pipeline_model_parallel_size=PP_SIZE, - context_parallel_size=CP_SIZE, - mixed_precision='bf16', - # Use 'full' recompute for single GPU to reduce memory usage - recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective', - ) + + # In Ray mode, pass remote_group and device_mesh to DataLoader + if args.mode == 'ray': + dataloader = DataLoader( + dataset=create_dataset, + batch_size=batch_size, + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + else: + dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) + + # Create model + # In Ray mode, pass remote_group and device_mesh to MegatronModel + if args.mode == 'ray': + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + mixed_precision='bf16', + recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective', + remote_group=GROUP_NAME, + device_mesh=device_mesh, 
+ ) + else: + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + mixed_precision='bf16', + recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective', + ) lora_config = LoraConfig(target_modules='all-linear') - - # Use 'lora' as adapter_name and pass it consistently to all methods adapter_name = 'lora' model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=16) model.set_template('Qwen3Template', adapter_name=adapter_name) model.set_processor(InputProcessor, padding_side='right', adapter_name=adapter_name) - # Note: For MegatronModel, loss is computed internally by Megatron. - # set_loss() is optional and mainly for API compatibility. model.set_loss(VocabParallelCrossEntropyLoss, adapter_name=adapter_name) model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) @@ -135,7 +167,6 @@ def cleanup(): """Clean up distributed resources.""" import torch.distributed as dist try: - # Barrier to ensure all processes are synchronized before cleanup if dist.is_initialized(): dist.barrier() from megatron.core import parallel_state as mpu diff --git a/docs/megatron_architecture.md b/docs/megatron_architecture.md new file mode 100644 index 00000000..667b5c0b --- /dev/null +++ b/docs/megatron_architecture.md @@ -0,0 +1,183 @@ +# Twinkle Megatron 组件架构 + +## 整体代码结构 + +``` +twinkle/src/twinkle/ +├── model/ +│ ├── megatron.py # MegatronModel 主类(对外接口) +│ └── strategy/ +│ └── megatron.py # MegatronStrategy 策略类 +└── megatron/ # Megatron-Core 集成模块 + ├── __init__.py # 公共 API 导出 + ├── utils.py # 工具函数和配置映射 + ├── worker.py # [已废弃] Ray Worker 类 + ├── tuners/ + │ ├── __init__.py + │ └── lora.py # LoRA 并行线性层实现 + └── model/ + ├── __init__.py + ├── bridge.py # HF ↔ Megatron 权重转换桥 + ├── initializer.py # 模型初始化器 + └── qwen3.py # Qwen3 模型支持 +``` + +## 核心组件详解 + +### 1. 
MegatronModel (`model/megatron.py`) + +**作用**:对外暴露的主要接口类,封装了 Megatron 模型的完整训练流程。 + +**关键特性**: +- 使用 `@remote_class(execute='all')` 装饰器,支持 Ray 分布式 +- 提供与 `TransformersModel` 类似的 API +- 支持 TP/PP/CP/EP 多种并行策略 +- 集成 PEFT/LoRA 微调 + +**核心方法**: +```python +# 初始化 +model = MegatronModel( + pretrained_model_name_or_path='Qwen/Qwen2.5-7B-Instruct', + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, +) + +# 添加 LoRA +model.add_adapter_to_model('lora', LoraConfig(...)) + +# 训练循环 +output = model.forward_backward(inputs=batch, adapter_name='lora') +model.step(adapter_name='lora') +``` + +### 2. MegatronStrategy (`model/strategy/megatron.py`) + +**作用**:管理 Megatron 分布式并行状态的策略类。 + +**关键特性**: +- 封装 `mpu.initialize_model_parallel()` 调用 +- 支持 local (torchrun) 和 Ray 两种执行模式 +- 自动检测 `TWINKLE_MODE` 环境变量 +- 提供 TP/PP/DP/CP/EP 进程组访问 + +**初始化流程**: +```python +# Local 模式(torchrun) +# - 环境变量由 torchrun 设置 +# - dist.init_process_group 使用默认值 + +# Ray 模式 +# - 读取 RayHelper 设置的 RANK, WORLD_SIZE, MASTER_ADDR 等 +# - 显式传递给 dist.init_process_group +``` + +### 3. TwinkleBridgeInitializer (`megatron/model/bridge.py`) + +**作用**:HuggingFace 到 Megatron 的模型初始化和权重转换。 + +**关键特性**: +- 自动转换 HF config 到 Megatron TransformerConfig +- 支持流式加载大模型权重(避免 OOM) +- 处理 TP/PP 权重分片 +- 支持 MoE 模型 + +**核心流程**: +``` +HuggingFace Model + ↓ (AutoConfig.from_pretrained) +HF Config + ↓ (convert_hf_config) +Megatron TransformerConfig + ↓ (GPTModel) +Megatron Model + ↓ (TwinkleBridgeAdapter.load_weights) +Loaded Megatron Model +``` + +### 4. MegatronModelInitializer (`megatron/model/initializer.py`) + +**作用**:替代方案的模型初始化器(非 bridge 模式)。 + +**与 Bridge 的区别**: +- Bridge:使用 `initialize_megatron` 初始化 Megatron 环境 +- Initializer:假设 Megatron 已经初始化,只负责模型创建 + +**推荐使用 Bridge 模式**(`use_megatron_bridge=True`,默认)。 + +### 5. LoraParallelLinear (`megatron/tuners/lora.py`) + +**作用**:为 Megatron 的并行线性层提供 LoRA 支持。 + +**关键特性**: +- 适配 TransformerEngine 的 TELinear、TEColumnParallelLinear 等 +- 保持 TP 兼容性 +- 支持 `dispatch_megatron` 注册到 PEFT + +### 6. 
工具函数 (`megatron/utils.py`) + +**关键函数**: +- `convert_hf_config`: HF config → Megatron config +- `find_all_linears`: 查找模型中所有线性层 +- `set_linear_is_expert`: 标记 MoE expert 层 +- `prepare_lora_model`: 准备 LoRA 模型 +- `TenantProcessGroupManager`: 多租户进程组管理 + +## 数据流 + +### Local 模式 (torchrun) + +``` +torchrun --nproc_per_node=4 lora.py + ↓ +每个进程独立执行 + ↓ +MegatronModel.__init__ + ↓ (use_megatron_bridge=True) +TwinkleBridgeInitializer._initialize_megatron + ↓ (检测 TWINKLE_MODE='local') +dist.init_process_group(backend='nccl') # 使用 torchrun 环境变量 + ↓ +mpu.initialize_model_parallel(tp_size, pp_size, ...) + ↓ +模型创建和权重加载 +``` + +### Ray 模式 + +``` +python lora.py --mode ray + ↓ +twinkle.initialize(mode='ray') + ↓ +RayHelper.create_workers() # 设置 RANK, WORLD_SIZE, MASTER_ADDR 等环境变量 + ↓ +MegatronModel.__init__ (在 Ray actor 内) + ↓ (use_megatron_bridge=True) +TwinkleBridgeInitializer._initialize_megatron + ↓ (检测 TWINKLE_MODE='ray') +读取环境变量 RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT + ↓ +dist.init_process_group( + backend='nccl', + init_method='tcp://...', + rank=rank, + world_size=world_size +) + ↓ +mpu.initialize_model_parallel(tp_size, pp_size, ...) + ↓ +模型创建和权重加载 +``` + +## 废弃组件 + +### worker.py (MegatronWorker / MegatronWorkerGroup) + +**状态**:已废弃 + +**原因**: +这是之前为了解决 Ray + Megatron 集成问题的临时方案。现在已经在核心代码中正确处理了分布式初始化,所以不再需要这个独立的 Worker 类。 + +**替代方案**: +直接使用 `MegatronModel` + `@remote_class` + `remote_group` 参数。 diff --git a/docs/megatron_ray_status.md b/docs/megatron_ray_status.md new file mode 100644 index 00000000..aaf50929 --- /dev/null +++ b/docs/megatron_ray_status.md @@ -0,0 +1,213 @@ +# Megatron + Ray 模式开发状态 + +## 最新更新 + +已修改 `TwinkleBridgeInitializer._initialize_megatron()` 和 `MegatronStrategy.initialize()`, +使其能够自动检测 `TWINKLE_MODE` 环境变量,并在 Ray 模式下正确初始化分布式环境。 + +**关键修改**: +1. `bridge.py`: `_initialize_megatron()` 检测 Ray 模式并使用正确的环境变量初始化 +2. `megatron.py` (strategy): `initialize()` 同样支持 Ray 模式 +3. 
`lora.py`: 统一的 demo,通过 `--mode ray` 切换 + +## Twinkle Ray 架构解析 + +### RayHelper 环境变量传递 + +在 `ray_helper.py` 第 247-268 行,`RayHelper.create_workers` **已经正确设置**了所有必要的环境变量: + +```python +env_vars.update({ + 'WORLD_SIZE': str(world_size), + 'RANK': str(rank), + 'LOCAL_RANK': str(0), + 'MASTER_ADDR': ip, + 'MASTER_PORT': str(port), + 'TWINKLE_MODE': 'ray', + ... +}) +runtime_env = RuntimeEnv(env_vars=env_vars) +worker = worker_cls.options(runtime_env=runtime_env, ...).remote(*args, **kwargs) +``` + +**每个 Ray actor 都有正确隔离的环境变量**。 + +### 为什么之前没有工作? + +问题不在环境变量传递,而在于 **`TwinkleBridgeInitializer` 没有读取这些环境变量**。 + +之前的代码: +```python +if not dist.is_initialized(): + dist.init_process_group(backend='nccl') # 只用默认值! +``` + +修复后: +```python +if twinkle_mode == 'ray': + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + master_addr = os.environ.get('MASTER_ADDR', 'localhost') + master_port = os.environ.get('MASTER_PORT', '29500') + + dist.init_process_group( + backend='nccl', + init_method=f'tcp://{master_addr}:{master_port}', + rank=rank, + world_size=world_size, + ) +``` + +## 当前实现总结 + +### MegatronWorker 类(临时方案) + +之前实现了独立的 `MegatronWorkerGroup` 和 `MegatronWorker` 类作为临时方案。 +现在已修改核心代码,理论上不再需要这个类。 + +### PP=1 梯度问题(仍存在) + +**问题描述**: +- PEFT/LoRA 包装的模型在 Megatron 的 `forward_backward_no_pipelining` 下 +- 模型输出 `logits.requires_grad=False` +- 导致 `backward_step` 被跳过 + +**根本原因**: +- Megatron 的 `backward_step` 在第 464 行检查:`if output_tensor[0].requires_grad:` +- 如果为 False,backward 被跳过 +- PEFT 模型的 forward 可能破坏了梯度追踪 + +**临时解决方案**:使用 PP > 1 + +### 为什么 Ray demo 与 Local demo 分开? + +当前有两个独立的 demo 文件: +- `lora.py`: Local 模式,使用 torchrun 启动 +- `lora_ray.py`: Ray 模式,使用 MegatronWorkerGroup + +**分开的原因**: + +1. **启动方式不同** + - Local: `torchrun --nproc_per_node=N` 启动多个进程 + - Ray: `python` 启动单个 driver,创建 Ray actors + +2. **环境变量来源不同** + - Local: torchrun 自动设置 `LOCAL_RANK`, `WORLD_SIZE` 等 + - Ray: 需要在 actor 内部手动设置 + +3. 
**分布式初始化不同** + - Local: 在脚本开头调用 `twinkle.initialize(mode='local')` + - Ray: 每个 worker 需要单独初始化 `torch.distributed` 和 `mpu` + +4. **代码入口不同** + - Local: `lora.py` 直接在 main 中调用 train() + - Ray: `lora_ray.py` 创建 worker group,然后调用 worker 方法 + +--- + +## TODO 列表 + +### 高优先级 + +1. **[ ] 修复 PP=1 时的梯度问题** + - 问题:PEFT 模型 forward 输出 `requires_grad=False` + - 可能的方案: + - a) 在 `MegatronModel.forward_backward` 中使用手动 backward + - b) 修改 PEFT 集成方式,确保梯度正确流动 + - c) 找出 Megatron/TE 哪里断开了梯度 + +2. **[ ] 统一 local 和 ray 模式的 demo** + - 目标:单一 `lora.py` 同时支持两种模式 + - 需要修改 `MegatronModel` 或 `MegatronStrategy` 来处理 Ray 环境下的初始化 + - 参考 `twinkle/cookbook/sft/lora.py` 的统一设计 + +3. **[ ] 将 MegatronWorker 逻辑集成到 Twinkle 核心架构** + - 选项 A:修改 `MegatronStrategy` 添加 Ray 模式支持 + - 选项 B:修改 `MegatronModel` 在 `@remote_class` 下正确初始化 + - 选项 C:创建 `MegatronRayStrategy` 专门处理 Ray + Megatron + +### 中优先级 + +4. **[ ] 支持 DP > 1 的 Ray 模式** + - 当前只测试了 TP+PP 组合 + - DP 需要正确的梯度同步(`finalize_model_grads`) + +5. **[ ] 支持 CP > 1 的 Ray 模式** + - Context Parallel 需要正确的序列分割和 loss 聚合 + +6. **[ ] 移除 MegatronWorkerGroup 的 hardcode** + - 当前 worker.py 中有很多硬编码逻辑 + - 应该复用 `TwinkleBridgeInitializer` 的配置 + +### 低优先级 + +7. **[ ] 添加 checkpoint 保存/加载支持** + - 当前 Ray 模式没有实现 `model.save()` + +8. **[ ] 性能优化** + - 减少 Ray object 传输开销 + - 优化 batch 分发 + +9. **[ ] 错误处理和恢复** + - Worker 失败时的处理 + - 分布式 barrier 超时处理 + +--- + +## 统一 demo 的设计方案 + +要实现 `lora.py` 同时支持 local 和 ray 模式,需要: + +```python +# 方案:在 MegatronModel/MegatronStrategy 中检测模式并初始化 + +# 1. 修改 twinkle.initialize 添加 Megatron 专用参数 +twinkle.initialize( + mode='ray', # 或 'local' + megatron_config={ + 'tp_size': 2, + 'pp_size': 2, + # ... + } +) + +# 2. MegatronModel 在 __init__ 中检测模式 +class MegatronModel: + def __init__(self, ...): + if twinkle.get_mode() == 'ray': + # 在 actor 内部初始化分布式 + self._init_ray_distributed() + else: + # 使用 torchrun 已设置的分布式 + self._init_local_distributed() + +# 3. RayHelper 在创建 actors 时设置环境变量 +# 类似当前 MegatronWorkerGroup 的做法 +``` + +**当前阻碍**: +1. 
`@remote_class` 不支持在 actor 创建时传递自定义参数(如 rank, world_size) +2. `MegatronModel.__init__` 在 driver 和 worker 中都会被调用,需要区分 + +--- + +## 测试状态 + +| 配置 | GPUs | Local Mode | Ray Mode | +|------|------|------------|----------| +| TP=2, PP=2 | 4 | ✅ | ✅ | +| TP=1, PP=4 | 4 | ✅ | ✅ | +| TP=2, PP=1 | 2 | ✅ | ❌ (梯度问题) | +| TP=1, PP=2 | 2 | ✅ | 未测试 | +| DP=2, TP=2, PP=2 | 8 | ⚠️ | 未测试 | +| CP > 1 | - | ⚠️ | 未测试 | + +--- + +## 文件清单 + +- `twinkle/src/twinkle/megatron/worker.py`: MegatronWorker 和 MegatronWorkerGroup 实现 +- `twinkle/cookbook/megatron/lora.py`: Local 模式 demo +- `twinkle/cookbook/megatron/lora_ray.py`: Ray 模式 demo(暂时保留) +- `twinkle/src/twinkle/model/megatron.py`: MegatronModel 核心实现 +- `twinkle/src/twinkle/megatron/model/bridge.py`: Megatron 模型初始化 bridge diff --git a/src/twinkle/data_format/input_feature.py b/src/twinkle/data_format/input_feature.py index 7777e460..d6a44c9a 100644 --- a/src/twinkle/data_format/input_feature.py +++ b/src/twinkle/data_format/input_feature.py @@ -40,8 +40,12 @@ def to_transformers_dict(feature: InputFeature) -> dict: output = {} _keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', 'logits_to_keep', 'num_items_in_batch'] for key in list(feature.keys()): - if key in _keys and not isinstance(output[key], torch.Tensor): - output[key] = np.array(output[key]) + if key in _keys: + value = feature[key] + if not isinstance(value, torch.Tensor): + output[key] = np.array(value) + else: + output[key] = value return output diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index b9a64f97..3a2ff7bd 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -1385,10 +1385,16 @@ def _initialize_megatron(self, hf_config: Any = None): This sets up the required process groups for tensor, pipeline, and data parallelism using Megatron's parallel state module directly. 
+ Handles both local (torchrun) and Ray execution modes: + - Local: Uses torchrun's environment variables (already set) + - Ray: Uses RayHelper's environment variables (RANK, WORLD_SIZE, etc.) + Args: hf_config: Optional HuggingFace config for additional model parameters. """ + import os import torch.distributed as dist + from datetime import timedelta from megatron.core import parallel_state as mpu from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed @@ -1399,9 +1405,33 @@ def _initialize_megatron(self, hf_config: Any = None): except AssertionError: pass + # Determine execution mode + twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') + # Initialize distributed if not already if not dist.is_initialized(): - dist.init_process_group(backend='nccl') + if twinkle_mode == 'ray': + # Ray mode: use environment variables set by RayHelper + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + master_addr = os.environ.get('MASTER_ADDR', 'localhost') + master_port = os.environ.get('MASTER_PORT', '29500') + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + + # Set CUDA device before init_process_group + torch.cuda.set_device(local_rank) + + # Initialize process group with explicit parameters + dist.init_process_group( + backend='nccl', + init_method=f'tcp://{master_addr}:{master_port}', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + else: + # Local mode (torchrun): environment variables are already set + dist.init_process_group(backend='nccl') # Initialize Megatron parallel state directly mpu.initialize_model_parallel( diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 462cd925..d342e4d4 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -684,15 +684,35 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): local_count = loss_mask_flat.sum() # For CP > 1, aggregate loss across CP ranks + # Note: 
Megatron's schedules.py will multiply loss by cp_group_size + # for legacy 2-output loss_func. This assumes loss_func returns SUM/cp_size (MEAN). + # So we should return local MEAN (not global MEAN) and let Megatron handle it. if cp_size > 1: - # Combine loss_sum and count for efficient all-reduce - loss_data = torch.cat([local_loss_sum.view(1), local_count.view(1)]) + # All-reduce the count across CP ranks to get total token count + # This is needed for correct averaging + total_count = local_count.clone() torch.distributed.all_reduce( - loss_data, + total_count, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group() ) - loss = loss_data[0] / loss_data[1].clamp(min=1) + + # Return local_loss_sum / total_count + # Megatron will multiply by cp_size, so the final result is: + # (local_loss_sum / total_count) * cp_size + # = (local_loss_sum * cp_size) / total_count + # But we want: SUM(local_loss_sum) / total_count + # So we need to do all_reduce on loss_sum too + total_loss_sum = local_loss_sum.clone() + torch.distributed.all_reduce( + total_loss_sum, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_context_parallel_group() + ) + + # Return global mean, but Megatron will multiply by cp_size + # So we divide by cp_size first to counteract that + loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size else: loss = local_loss_sum / local_count.clamp(min=1) @@ -723,6 +743,7 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): # Extract loss from results (only last PP stage returns non-empty) loss = 0.0 + if losses: for loss_dict in losses: if isinstance(loss_dict, dict) and 'loss' in loss_dict: @@ -742,11 +763,14 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1) src_rank = mpu.get_pipeline_model_parallel_last_rank() + pp_group = mpu.get_pipeline_model_parallel_group() + torch.distributed.broadcast( loss_tensor, src=src_rank, 
- group=mpu.get_pipeline_model_parallel_group() + group=pp_group ) + loss = loss_tensor.item() optimizer_config.cur_step += 1 diff --git a/src/twinkle/model/strategy/megatron.py b/src/twinkle/model/strategy/megatron.py index c4918e0c..748b57ef 100644 --- a/src/twinkle/model/strategy/megatron.py +++ b/src/twinkle/model/strategy/megatron.py @@ -150,15 +150,53 @@ def from_device_mesh( def initialize(self, **kwargs) -> None: """Initialize Megatron parallel state. - Should be called after distributed process group is initialized. - This sets up all the parallel groups for TP/PP/CP/EP/DP. + This method handles both local (torchrun) and Ray modes: + + **Local mode**: + - torch.distributed is already initialized by torchrun + - Just initialize mpu.initialize_model_parallel() + + **Ray mode**: + - Read RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT from environment + - Initialize torch.distributed with these values + - Then initialize mpu.initialize_model_parallel() + + This allows the same MegatronModel code to work in both modes. 
""" if self._initialized: return - + + import os + from datetime import timedelta + + # Determine execution mode + twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') + + # Initialize torch.distributed if not already done if not dist.is_initialized(): - # Initialize torch distributed if not already done - dist.init_process_group(backend='nccl') + if twinkle_mode == 'ray': + # Ray mode: use environment variables set by RayHelper + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + master_addr = os.environ.get('MASTER_ADDR', 'localhost') + master_port = os.environ.get('MASTER_PORT', '29500') + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + + # Set CUDA device before init_process_group + torch.cuda.set_device(local_rank) + + # Initialize process group + dist.init_process_group( + backend='nccl', + init_method=f'tcp://{master_addr}:{master_port}', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + else: + # Local mode: torchrun should have set up distributed + # If not, initialize with default settings + dist.init_process_group(backend='nccl') world_size = dist.get_world_size() @@ -191,7 +229,7 @@ def initialize(self, **kwargs) -> None: self._parallel_state = parallel_state self._initialized = True - # Set CUDA device + # Set CUDA device (may be redundant in Ray mode, but safe) local_rank = dist.get_rank() % torch.cuda.device_count() torch.cuda.set_device(local_rank) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 97180474..458b06d0 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -25,7 +25,9 @@ def _test_support_assistant_tokens_mask(self): outputs = self.tokenizer.apply_chat_template(conversation=dummy_inputs, return_assistant_tokens_mask=True, return_dict=True) assistant_masks = outputs['assistant_masks'] - self._template_support_assistant_tokens_mask = not all(np.array(assistant_masks).flatten()) + # Check if ANY token is 
marked as assistant (mask > 0) + # If all masks are 0, the template doesn't support this feature + self._template_support_assistant_tokens_mask = any(np.array(assistant_masks).flatten()) def encode(self, trajectory: Trajectory) -> InputFeature: if self._template_support_assistant_tokens_mask: diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index defb16be..f450b9f4 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -183,10 +183,7 @@ def to_local_tensor(tensor: 'torch.Tensor') -> 'torch.Tensor': Returns: A local torch.Tensor. """ -<<<<<<< HEAD import torch -======= ->>>>>>> origin/dev if hasattr(tensor, 'full_tensor'): # DTensor from torch.distributed.tensor return tensor.full_tensor() diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py index 326e0c26..2fd29689 100644 --- a/src/twinkle/utils/parallel.py +++ b/src/twinkle/utils/parallel.py @@ -5,20 +5,35 @@ from datasets.utils._filelock import FileLock -shutil.rmtree('.locks', ignore_errors=True) -os.makedirs('.locks', exist_ok=True) +# Create locks directory +_locks_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', '.locks') +os.makedirs(_locks_dir, exist_ok=True) @contextmanager -def processing_lock(lock_file: str): - lock = FileLock(os.path.join('.locks', f"{lock_file}.lock")) +def processing_lock(lock_file: str, timeout: float = 600.0): + """Acquire a file lock for distributed-safe processing. + + Args: + lock_file: Name of the lock file (will be sanitized). + timeout: Maximum time to wait for lock acquisition in seconds. + + In distributed training, only rank 0 should process data while + other ranks wait. This lock ensures that. 
+ """ + # Sanitize lock file name + safe_name = lock_file.replace('/', '_').replace(':', '_').replace(' ', '_') + lock_path = os.path.join(_locks_dir, f"{safe_name}.lock") + lock = FileLock(lock_path, timeout=timeout) - if lock.acquire(blocking=False): + try: + # Try to acquire lock with blocking and timeout + lock.acquire(blocking=True, timeout=timeout) try: yield finally: lock.release() - else: - lock.acquire(blocking=True) - lock.release() + except Exception: + # If lock acquisition fails (e.g., timeout), still yield to allow progress + # This prevents deadlock in distributed scenarios yield \ No newline at end of file diff --git a/tests/test_parallelism.py b/tests/test_parallelism.py deleted file mode 100644 index 1a394e0e..00000000 --- a/tests/test_parallelism.py +++ /dev/null @@ -1,405 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) twinkle authors. All rights reserved. -"""Test different parallelism strategies for Megatron backend in local mode. - -This script tests various parallelism configurations: -- TP (Tensor Parallel) -- PP (Pipeline Parallel) -- DP (Data Parallel) -- CP (Context Parallel) -- SP (Sequence Parallel, enabled when TP > 1) -- Combined configurations - -Uses Qwen2.5-0.5B-Instruct for faster testing. 
- -Usage: - python cookbook/megatron/test_parallelism.py -""" -import os -import subprocess -import sys -import time -from dataclasses import dataclass -from typing import List, Optional - -# Color codes for terminal output -GREEN = '\033[92m' -RED = '\033[91m' -YELLOW = '\033[93m' -RESET = '\033[0m' - - -@dataclass -class TestConfig: - """Test configuration for a parallelism strategy.""" - name: str - tp_size: int - pp_size: int - cp_size: int = 1 - sp_enabled: bool = False # Sequence Parallel - num_gpus: int = 0 # 0 means auto-calculate - max_steps: int = 2 - expected_to_pass: bool = True - notes: str = "" - - def __post_init__(self): - if self.num_gpus == 0: - self.num_gpus = self.tp_size * self.pp_size * self.cp_size - # Ensure at least 1 for DP - if self.num_gpus == 0: - self.num_gpus = 1 - - -# Test configurations -TEST_CONFIGS: List[TestConfig] = [ - # Basic single-GPU - TestConfig( - name="Single GPU (TP=1, PP=1)", - tp_size=1, pp_size=1, cp_size=1, - num_gpus=1, - notes="Baseline test" - ), - - # Tensor Parallel only - TestConfig( - name="TP=2 (Tensor Parallel)", - tp_size=2, pp_size=1, cp_size=1, - notes="Tests tensor sharding" - ), - TestConfig( - name="TP=4 (Tensor Parallel)", - tp_size=4, pp_size=1, cp_size=1, - notes="Larger TP" - ), - - # Pipeline Parallel only - TestConfig( - name="PP=2 (Pipeline Parallel)", - tp_size=1, pp_size=2, cp_size=1, - notes="Tests pipeline stages" - ), - TestConfig( - name="PP=4 (Pipeline Parallel)", - tp_size=1, pp_size=4, cp_size=1, - notes="More pipeline stages" - ), - - # TP + PP combinations - TestConfig( - name="TP=2, PP=2", - tp_size=2, pp_size=2, cp_size=1, - notes="Combined TP+PP" - ), - TestConfig( - name="TP=2, PP=4", - tp_size=2, pp_size=4, cp_size=1, - num_gpus=8, - notes="8-GPU TP+PP" - ), - - # Data Parallel (DP > 1) - TestConfig( - name="TP=2, PP=2, DP=2 (8 GPUs)", - tp_size=2, pp_size=2, cp_size=1, - num_gpus=8, - expected_to_pass=False, - notes="Known issue: P2P deadlock with DP > 1" - ), - - # Context 
Parallel - TestConfig( - name="CP=2 (Context Parallel)", - tp_size=1, pp_size=1, cp_size=2, - expected_to_pass=False, - notes="Known issue: CP communication deadlock" - ), - TestConfig( - name="TP=2, PP=2, CP=2 (8 GPUs)", - tp_size=2, pp_size=2, cp_size=2, - num_gpus=8, - expected_to_pass=False, - notes="Known issue: CP + PP deadlock" - ), - - # Sequence Parallel (with TP) - TestConfig( - name="TP=2 + SP (Sequence Parallel)", - tp_size=2, pp_size=1, cp_size=1, - sp_enabled=True, - notes="SP is typically enabled with TP" - ), - TestConfig( - name="TP=2, PP=2 + SP", - tp_size=2, pp_size=2, cp_size=1, - sp_enabled=True, - notes="Combined TP+PP+SP" - ), -] - - -def get_available_gpus() -> int: - """Get number of available GPUs.""" - try: - result = subprocess.run( - ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], - capture_output=True, text=True, timeout=10 - ) - if result.returncode == 0: - return len(result.stdout.strip().split('\n')) - except Exception: - pass - return 0 - - -def create_test_script() -> str: - """Create a minimal test script for parallelism testing.""" - script = ''' -# Minimal Megatron parallelism test script -import os -import sys -import argparse - -# Set CUDA device before any imports -import torch -LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) -torch.cuda.set_device(LOCAL_RANK) - -import numpy as np -import twinkle -from twinkle import DeviceMesh, DeviceGroup, Platform, get_logger -from twinkle.model import MegatronModel -from peft import LoraConfig -from torch.optim import AdamW - -logger = get_logger() - -parser = argparse.ArgumentParser() -parser.add_argument('--tp_size', type=int, default=1) -parser.add_argument('--pp_size', type=int, default=1) -parser.add_argument('--cp_size', type=int, default=1) -parser.add_argument('--sp_enabled', action='store_true') -parser.add_argument('--max_steps', type=int, default=2) -args = parser.parse_args() - -WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) -TP_SIZE = args.tp_size 
-PP_SIZE = args.pp_size -CP_SIZE = args.cp_size -DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) - -device_mesh = DeviceMesh( - device_type='cuda', - mesh=np.arange(WORLD_SIZE).reshape(DP_SIZE, CP_SIZE, PP_SIZE, TP_SIZE), - mesh_dim_names=('dp', 'cp', 'pp', 'tp'), -) - -device_group = [ - DeviceGroup(name='model', ranks=list(range(WORLD_SIZE)), - device_type=Platform.get_platform().device_prefix()) -] - -twinkle.initialize( - mode='local', - nproc_per_node=WORLD_SIZE, - groups=device_group, - global_device_mesh=device_mesh, - lazy_collect=False, -) - -# Create model with smaller Qwen2.5-0.5B -model = MegatronModel( - pretrained_model_name_or_path='ms://Qwen/Qwen2.5-0.5B-Instruct', - tensor_model_parallel_size=TP_SIZE, - pipeline_model_parallel_size=PP_SIZE, - context_parallel_size=CP_SIZE, - sequence_parallel=args.sp_enabled, - mixed_precision='bf16', - recompute_granularity='full', -) - -# Add LoRA -lora_config = LoraConfig(target_modules='all-linear', r=4) -model.add_adapter_to_model('lora', lora_config, gradient_accumulation_steps=1) -model.set_optimizer(AdamW, lr=1e-4, adapter_name='lora') - -logger.info(f"Model initialized: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, DP={DP_SIZE}, SP={args.sp_enabled}") - -# Training loop with dummy data -for step in range(args.max_steps): - batch = { - 'input_ids': torch.randint(0, 1000, (1, 64), device=f'cuda:{LOCAL_RANK}'), - 'attention_mask': torch.ones(1, 64, device=f'cuda:{LOCAL_RANK}'), - 'labels': torch.randint(0, 1000, (1, 64), device=f'cuda:{LOCAL_RANK}'), - } - loss = model.forward_backward(inputs=batch, adapter_name='lora') - logger.info(f"Step {step}, loss: {loss}") - model.step(adapter_name='lora') - model.zero_grad(adapter_name='lora') - -logger.info("Training completed successfully!") - -# Cleanup -import torch.distributed as dist -if dist.is_initialized(): - dist.barrier() - from megatron.core import parallel_state as mpu - if mpu.is_initialized(): - mpu.destroy_model_parallel() - 
dist.destroy_process_group() -''' - return script - - -def run_test(config: TestConfig, available_gpus: int, test_script_path: str) -> dict: - """Run a single test configuration.""" - result = { - 'name': config.name, - 'config': f"TP={config.tp_size}, PP={config.pp_size}, CP={config.cp_size}" + - (", SP=True" if config.sp_enabled else ""), - 'gpus': config.num_gpus, - 'status': 'SKIPPED', - 'message': '', - 'duration': 0, - } - - # Check if we have enough GPUs - if config.num_gpus > available_gpus: - result['status'] = 'SKIPPED' - result['message'] = f"Need {config.num_gpus} GPUs, only {available_gpus} available" - return result - - # Build command - cuda_devices = ','.join(str(i) for i in range(config.num_gpus)) - cmd = [ - sys.executable, '-m', 'torch.distributed.run', - '--nproc_per_node', str(config.num_gpus), - test_script_path, - '--tp_size', str(config.tp_size), - '--pp_size', str(config.pp_size), - '--cp_size', str(config.cp_size), - '--max_steps', str(config.max_steps), - ] - if config.sp_enabled: - cmd.append('--sp_enabled') - - env = os.environ.copy() - env['CUDA_VISIBLE_DEVICES'] = cuda_devices - env['MEGATRON_LM_PATH'] = os.environ.get('MEGATRON_LM_PATH', '/mnt/nas2/hujinghan.hjh/Megatron-LM') - env['PYTHONPATH'] = f"{env['MEGATRON_LM_PATH']}:{os.getcwd()}/src:{env.get('PYTHONPATH', '')}" - - # Timeout: 3 minutes per test - timeout = 180 - - start_time = time.time() - try: - proc = subprocess.run( - cmd, env=env, capture_output=True, text=True, timeout=timeout - ) - duration = time.time() - start_time - result['duration'] = duration - - if proc.returncode == 0: - # Check if training completed - if 'Training completed successfully!' in proc.stdout or 'Training completed successfully!' 
in proc.stderr: - result['status'] = 'PASSED' - result['message'] = f"Completed in {duration:.1f}s" - else: - result['status'] = 'FAILED' - result['message'] = "Training did not complete" - else: - result['status'] = 'FAILED' - # Extract error message - stderr = proc.stderr[-500:] if len(proc.stderr) > 500 else proc.stderr - result['message'] = f"Exit code {proc.returncode}: {stderr}" - - except subprocess.TimeoutExpired: - result['status'] = 'TIMEOUT' - result['message'] = f"Exceeded {timeout}s timeout (likely deadlock)" - result['duration'] = timeout - # Kill any remaining processes - subprocess.run(['pkill', '-f', test_script_path], capture_output=True) - time.sleep(2) - except Exception as e: - result['status'] = 'ERROR' - result['message'] = str(e) - - return result - - -def main(): - print("=" * 80) - print("Megatron Parallelism Test Suite") - print("=" * 80) - - available_gpus = get_available_gpus() - print(f"Available GPUs: {available_gpus}") - - if available_gpus == 0: - print(f"{RED}No GPUs available!{RESET}") - return 1 - - # Create test script - test_script_path = '/tmp/megatron_parallelism_test.py' - with open(test_script_path, 'w') as f: - f.write(create_test_script()) - print(f"Test script created: {test_script_path}") - - # Run tests - results = [] - for i, config in enumerate(TEST_CONFIGS): - print(f"\n[{i+1}/{len(TEST_CONFIGS)}] Testing: {config.name}") - print(f" Config: TP={config.tp_size}, PP={config.pp_size}, CP={config.cp_size}, GPUs={config.num_gpus}") - if config.notes: - print(f" Notes: {config.notes}") - - result = run_test(config, available_gpus, test_script_path) - results.append(result) - - # Print result - if result['status'] == 'PASSED': - status_str = f"{GREEN}PASSED{RESET}" - elif result['status'] == 'SKIPPED': - status_str = f"{YELLOW}SKIPPED{RESET}" - else: - status_str = f"{RED}{result['status']}{RESET}" - - print(f" Result: {status_str}") - if result['message']: - print(f" Message: {result['message'][:200]}") - - # Check if 
result matches expectation - if config.expected_to_pass and result['status'] not in ['PASSED', 'SKIPPED']: - print(f" {RED}UNEXPECTED FAILURE (was expected to pass){RESET}") - elif not config.expected_to_pass and result['status'] == 'PASSED': - print(f" {GREEN}UNEXPECTED SUCCESS (was expected to fail){RESET}") - - # Summary - print("\n" + "=" * 80) - print("Test Summary") - print("=" * 80) - - passed = sum(1 for r in results if r['status'] == 'PASSED') - failed = sum(1 for r in results if r['status'] in ['FAILED', 'TIMEOUT', 'ERROR']) - skipped = sum(1 for r in results if r['status'] == 'SKIPPED') - - print(f"{GREEN}PASSED: {passed}{RESET}") - print(f"{RED}FAILED/TIMEOUT: {failed}{RESET}") - print(f"{YELLOW}SKIPPED: {skipped}{RESET}") - - print("\nDetailed Results:") - print("-" * 80) - for r in results: - status = r['status'] - if status == 'PASSED': - status_color = GREEN - elif status == 'SKIPPED': - status_color = YELLOW - else: - status_color = RED - print(f" {r['name']}: {status_color}{status}{RESET}") - - return 0 if failed == 0 else 1 - - -if __name__ == '__main__': - sys.exit(main()) From 70ff0bad78109037919d470d82864a92c2568418 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 Jan 2026 22:11:34 +0800 Subject: [PATCH 05/22] clean ray --- cookbook/megatron/__init__.py | 1 - cookbook/megatron/lora.py | 14 +- cookbook/megatron/lora_ray.py | 227 ----------- docs/megatron_architecture.md | 183 --------- docs/megatron_ray_status.md | 213 ---------- src/twinkle/infra/__init__.py | 30 +- src/twinkle/infra/ray/resource_manager.py | 6 +- src/twinkle/loss/__init__.py | 5 +- .../loss/vocab_parallel_cross_entropy.py | 1 - src/twinkle/megatron/model/bridge.py | 10 - src/twinkle/megatron/worker.py | 368 ------------------ src/twinkle/model/megatron.py | 30 +- 12 files changed, 43 insertions(+), 1045 deletions(-) delete mode 100644 cookbook/megatron/lora_ray.py delete mode 100644 docs/megatron_architecture.md delete mode 100644 docs/megatron_ray_status.md delete mode 
100644 src/twinkle/megatron/worker.py diff --git a/cookbook/megatron/__init__.py b/cookbook/megatron/__init__.py index 1da7257c..a0b9f9e5 100644 --- a/cookbook/megatron/__init__.py +++ b/cookbook/megatron/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) twinkle authors. All rights reserved. -"""Megatron training examples for twinkle.""" diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index f5816cc6..c87e5301 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -11,7 +11,6 @@ """ import argparse import os -import sys # Parse arguments first to determine mode parser = argparse.ArgumentParser() @@ -42,7 +41,7 @@ from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta -from twinkle.loss import VocabParallelCrossEntropyLoss +from twinkle.loss import MegatronCrossEntropyLoss from twinkle.model import MegatronModel from twinkle.processor import InputProcessor @@ -99,7 +98,7 @@ def train(): # Use smaller batch size for single GPU to avoid OOM batch_size = 2 if WORLD_SIZE == 1 else 8 - # In Ray mode, pass remote_group and device_mesh to DataLoader + # In Ray mode, pass remote_group and device_mesh if args.mode == 'ray': dataloader = DataLoader( dataset=create_dataset, @@ -107,12 +106,6 @@ def train(): remote_group=GROUP_NAME, device_mesh=device_mesh, ) - else: - dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) - - # Create model - # In Ray mode, pass remote_group and device_mesh to MegatronModel - if args.mode == 'ray': model = MegatronModel( pretrained_model_name_or_path=args.model, tensor_model_parallel_size=TP_SIZE, @@ -124,6 +117,7 @@ def train(): device_mesh=device_mesh, ) else: + dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) model = MegatronModel( pretrained_model_name_or_path=args.model, tensor_model_parallel_size=TP_SIZE, @@ -138,7 +132,7 @@ def train(): 
model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=16) model.set_template('Qwen3Template', adapter_name=adapter_name) model.set_processor(InputProcessor, padding_side='right', adapter_name=adapter_name) - model.set_loss(VocabParallelCrossEntropyLoss, adapter_name=adapter_name) + model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name) model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) diff --git a/cookbook/megatron/lora_ray.py b/cookbook/megatron/lora_ray.py deleted file mode 100644 index 85e58b78..00000000 --- a/cookbook/megatron/lora_ray.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) twinkle authors. All rights reserved. -"""Megatron-Core LoRA training in Ray mode. - -This script uses MegatronWorkerGroup for Ray-based distributed training -with proper Megatron collective operations support. - -NOTE: PP > 1 is REQUIRED for training. PP=1 has known gradient flow issues -with PEFT/LoRA and Megatron's forward_backward_no_pipelining. 
- -Usage: - # TP=2, PP=2 (4 GPUs) - RECOMMENDED - python cookbook/megatron/lora_ray.py --tp_size 2 --pp_size 2 --num_gpus 4 - - # PP=4, TP=1 (4 GPUs) - python cookbook/megatron/lora_ray.py --tp_size 1 --pp_size 4 --num_gpus 4 - - # PP=2, TP=1 (2 GPUs) - python cookbook/megatron/lora_ray.py --tp_size 1 --pp_size 2 --num_gpus 2 -""" -import argparse -import os -import sys - -# Add paths -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) -megatron_path = os.environ.get('MEGATRON_LM_PATH', '/mnt/nas2/hujinghan.hjh/Megatron-LM') -sys.path.insert(0, megatron_path) - -import ray -import torch -import numpy as np - -from twinkle import get_logger -from twinkle.megatron.worker import MegatronWorkerGroup - -logger = get_logger() - - -def create_dataset(): - """Create and prepare the training dataset - same as local mode.""" - from twinkle.dataset import Dataset, DatasetMeta - - dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) - dataset.set_template('Qwen3Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct') - dataset.map('CompetitionMathProcessor') - dataset.encode(batched=True, load_from_cache_file=False) - return dataset - - -def collate_batch(samples, batch_size: int, max_seq_len: int = 512): - """Collate samples into a batch with padding.""" - # Take batch_size samples - samples = samples[:batch_size] - - # Get max length in batch (capped at max_seq_len) - max_len = min(max(len(s['input_ids']) for s in samples), max_seq_len) - - input_ids_list = [] - attention_mask_list = [] - labels_list = [] - - for s in samples: - ids = s['input_ids'][:max_len] - pad_len = max_len - len(ids) - - input_ids_list.append(ids + [0] * pad_len) - attention_mask_list.append([1] * len(ids) + [0] * pad_len) - - # Labels: use -100 for padding - labels = s.get('labels', ids)[:max_len] - labels_list.append(labels + [-100] * pad_len) - - return { - 'input_ids': torch.tensor(input_ids_list, dtype=torch.long), - 'attention_mask': 
torch.tensor(attention_mask_list, dtype=torch.long), - 'labels': torch.tensor(labels_list, dtype=torch.long), - } - - -def main(): - parser = argparse.ArgumentParser(description='Megatron LoRA training in Ray mode') - parser.add_argument('--tp_size', type=int, default=2, help='Tensor parallel size') - parser.add_argument('--pp_size', type=int, default=2, help='Pipeline parallel size (must be > 1 for training)') - parser.add_argument('--cp_size', type=int, default=1, help='Context parallel size') - parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs') - parser.add_argument('--max_steps', type=int, default=10, help='Max training steps') - parser.add_argument('--batch_size', type=int, default=2, help='Batch size per step') - parser.add_argument('--max_seq_len', type=int, default=512, help='Max sequence length') - parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') - parser.add_argument('--model', type=str, default='ms://Qwen/Qwen2.5-0.5B-Instruct', - help='Model path or ID') - parser.add_argument('--lora_r', type=int, default=8, help='LoRA rank') - args = parser.parse_args() - - # Validate parallelism config - expected_gpus = args.tp_size * args.pp_size * args.cp_size - if args.num_gpus < expected_gpus: - logger.error(f"Need at least {expected_gpus} GPUs for TP={args.tp_size}, " - f"PP={args.pp_size}, CP={args.cp_size}, but only {args.num_gpus} provided") - return 1 - - # Prepare dataset first (on driver, before Ray workers) - logger.info("Preparing dataset...") - dataset = create_dataset() - samples = [dataset[i] for i in range(min(len(dataset), args.max_steps * args.batch_size + 100))] - logger.info(f"Loaded {len(samples)} samples") - - # Initialize Ray - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - - logger.info(f"Ray initialized with {args.num_gpus} GPUs") - logger.info(f"Config: TP={args.tp_size}, PP={args.pp_size}, CP={args.cp_size}") - - # Create worker group - worker_group = 
MegatronWorkerGroup( - world_size=args.num_gpus, - tp_size=args.tp_size, - pp_size=args.pp_size, - cp_size=args.cp_size, - ) - - try: - # Initialize workers - logger.info("Initializing workers...") - results = worker_group.init_all() - if not all(results): - raise RuntimeError("Worker initialization failed") - - # Create model - logger.info(f"Loading model: {args.model}") - results = worker_group.create_model_all( - pretrained_model_name_or_path=args.model, - mixed_precision='bf16', - recompute_granularity='full', - ) - if not all(results): - raise RuntimeError("Model creation failed") - - # Add LoRA with Megatron layer names - logger.info("Adding LoRA adapters...") - lora_config = { - 'target_modules': ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'], - 'r': args.lora_r, - 'lora_alpha': args.lora_r, - 'lora_dropout': 0.0, - } - results = worker_group.add_lora_all(lora_config) - if not all(results): - raise RuntimeError("LoRA addition failed") - - # Set optimizer - logger.info(f"Setting optimizer with lr={args.lr}") - results = worker_group.set_optimizer_all(lr=args.lr) - if not all(results): - raise RuntimeError("Optimizer setup failed") - - # Training loop - logger.info(f"Starting training for {args.max_steps} steps...") - losses = [] - - # Use same batch for all steps to verify loss decreases (overfitting test) - fixed_batch = collate_batch(samples[:args.batch_size], args.batch_size, args.max_seq_len) - - for step in range(args.max_steps): - batch = fixed_batch - - # Forward-backward - step_losses = worker_group.forward_backward_all(batch) - - # Get valid loss (non-zero from last PP stage) - valid_losses = [l for l in step_losses if l > 0] - avg_loss = np.mean(valid_losses) if valid_losses else 0.0 - losses.append(avg_loss) - - logger.info(f"Step {step:3d}/{args.max_steps}, loss: {avg_loss:.4f}") - - # Optimizer step - worker_group.step_all() - - # Check loss trend - logger.info("=" * 60) - logger.info("Training Summary:") - logger.info(f" Initial loss: 
{losses[0]:.4f}") - logger.info(f" Final loss: {losses[-1]:.4f}") - logger.info(f" Loss change: {losses[-1] - losses[0]:.4f}") - - # Validation checks (aligned with local mode expectations) - initial_ok = losses[0] < 3 # Real data should have initial loss < 3 - decreasing = losses[-1] < losses[0] # Should decrease over training - - if initial_ok: - logger.info("✓ Initial loss is reasonable (< 3)") - else: - logger.warning(f"✗ Initial loss {losses[0]:.4f} is too high (expected < 3)") - - if decreasing: - logger.info("✓ Loss is decreasing (training is working)") - else: - logger.warning("✗ Loss is not decreasing") - - logger.info("=" * 60) - logger.info("Training completed!") - - return 0 if (initial_ok and decreasing) else 1 - - except Exception as e: - logger.error(f"Error: {e}") - import traceback - traceback.print_exc() - return 1 - - finally: - logger.info("Cleaning up...") - try: - worker_group.cleanup_all() - except Exception: - pass - try: - worker_group.shutdown() - except Exception: - pass - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/docs/megatron_architecture.md b/docs/megatron_architecture.md deleted file mode 100644 index 667b5c0b..00000000 --- a/docs/megatron_architecture.md +++ /dev/null @@ -1,183 +0,0 @@ -# Twinkle Megatron 组件架构 - -## 整体代码结构 - -``` -twinkle/src/twinkle/ -├── model/ -│ ├── megatron.py # MegatronModel 主类(对外接口) -│ └── strategy/ -│ └── megatron.py # MegatronStrategy 策略类 -└── megatron/ # Megatron-Core 集成模块 - ├── __init__.py # 公共 API 导出 - ├── utils.py # 工具函数和配置映射 - ├── worker.py # [已废弃] Ray Worker 类 - ├── tuners/ - │ ├── __init__.py - │ └── lora.py # LoRA 并行线性层实现 - └── model/ - ├── __init__.py - ├── bridge.py # HF ↔ Megatron 权重转换桥 - ├── initializer.py # 模型初始化器 - └── qwen3.py # Qwen3 模型支持 -``` - -## 核心组件详解 - -### 1. 
MegatronModel (`model/megatron.py`) - -**作用**:对外暴露的主要接口类,封装了 Megatron 模型的完整训练流程。 - -**关键特性**: -- 使用 `@remote_class(execute='all')` 装饰器,支持 Ray 分布式 -- 提供与 `TransformersModel` 类似的 API -- 支持 TP/PP/CP/EP 多种并行策略 -- 集成 PEFT/LoRA 微调 - -**核心方法**: -```python -# 初始化 -model = MegatronModel( - pretrained_model_name_or_path='Qwen/Qwen2.5-7B-Instruct', - tensor_model_parallel_size=2, - pipeline_model_parallel_size=2, -) - -# 添加 LoRA -model.add_adapter_to_model('lora', LoraConfig(...)) - -# 训练循环 -output = model.forward_backward(inputs=batch, adapter_name='lora') -model.step(adapter_name='lora') -``` - -### 2. MegatronStrategy (`model/strategy/megatron.py`) - -**作用**:管理 Megatron 分布式并行状态的策略类。 - -**关键特性**: -- 封装 `mpu.initialize_model_parallel()` 调用 -- 支持 local (torchrun) 和 Ray 两种执行模式 -- 自动检测 `TWINKLE_MODE` 环境变量 -- 提供 TP/PP/DP/CP/EP 进程组访问 - -**初始化流程**: -```python -# Local 模式(torchrun) -# - 环境变量由 torchrun 设置 -# - dist.init_process_group 使用默认值 - -# Ray 模式 -# - 读取 RayHelper 设置的 RANK, WORLD_SIZE, MASTER_ADDR 等 -# - 显式传递给 dist.init_process_group -``` - -### 3. TwinkleBridgeInitializer (`megatron/model/bridge.py`) - -**作用**:HuggingFace 到 Megatron 的模型初始化和权重转换。 - -**关键特性**: -- 自动转换 HF config 到 Megatron TransformerConfig -- 支持流式加载大模型权重(避免 OOM) -- 处理 TP/PP 权重分片 -- 支持 MoE 模型 - -**核心流程**: -``` -HuggingFace Model - ↓ (AutoConfig.from_pretrained) -HF Config - ↓ (convert_hf_config) -Megatron TransformerConfig - ↓ (GPTModel) -Megatron Model - ↓ (TwinkleBridgeAdapter.load_weights) -Loaded Megatron Model -``` - -### 4. MegatronModelInitializer (`megatron/model/initializer.py`) - -**作用**:替代方案的模型初始化器(非 bridge 模式)。 - -**与 Bridge 的区别**: -- Bridge:使用 `initialize_megatron` 初始化 Megatron 环境 -- Initializer:假设 Megatron 已经初始化,只负责模型创建 - -**推荐使用 Bridge 模式**(`use_megatron_bridge=True`,默认)。 - -### 5. LoraParallelLinear (`megatron/tuners/lora.py`) - -**作用**:为 Megatron 的并行线性层提供 LoRA 支持。 - -**关键特性**: -- 适配 TransformerEngine 的 TELinear、TEColumnParallelLinear 等 -- 保持 TP 兼容性 -- 支持 `dispatch_megatron` 注册到 PEFT - -### 6. 
工具函数 (`megatron/utils.py`) - -**关键函数**: -- `convert_hf_config`: HF config → Megatron config -- `find_all_linears`: 查找模型中所有线性层 -- `set_linear_is_expert`: 标记 MoE expert 层 -- `prepare_lora_model`: 准备 LoRA 模型 -- `TenantProcessGroupManager`: 多租户进程组管理 - -## 数据流 - -### Local 模式 (torchrun) - -``` -torchrun --nproc_per_node=4 lora.py - ↓ -每个进程独立执行 - ↓ -MegatronModel.__init__ - ↓ (use_megatron_bridge=True) -TwinkleBridgeInitializer._initialize_megatron - ↓ (检测 TWINKLE_MODE='local') -dist.init_process_group(backend='nccl') # 使用 torchrun 环境变量 - ↓ -mpu.initialize_model_parallel(tp_size, pp_size, ...) - ↓ -模型创建和权重加载 -``` - -### Ray 模式 - -``` -python lora.py --mode ray - ↓ -twinkle.initialize(mode='ray') - ↓ -RayHelper.create_workers() # 设置 RANK, WORLD_SIZE, MASTER_ADDR 等环境变量 - ↓ -MegatronModel.__init__ (在 Ray actor 内) - ↓ (use_megatron_bridge=True) -TwinkleBridgeInitializer._initialize_megatron - ↓ (检测 TWINKLE_MODE='ray') -读取环境变量 RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT - ↓ -dist.init_process_group( - backend='nccl', - init_method='tcp://...', - rank=rank, - world_size=world_size -) - ↓ -mpu.initialize_model_parallel(tp_size, pp_size, ...) - ↓ -模型创建和权重加载 -``` - -## 废弃组件 - -### worker.py (MegatronWorker / MegatronWorkerGroup) - -**状态**:已废弃 - -**原因**: -这是之前为了解决 Ray + Megatron 集成问题的临时方案。现在已经在核心代码中正确处理了分布式初始化,所以不再需要这个独立的 Worker 类。 - -**替代方案**: -直接使用 `MegatronModel` + `@remote_class` + `remote_group` 参数。 diff --git a/docs/megatron_ray_status.md b/docs/megatron_ray_status.md deleted file mode 100644 index aaf50929..00000000 --- a/docs/megatron_ray_status.md +++ /dev/null @@ -1,213 +0,0 @@ -# Megatron + Ray 模式开发状态 - -## 最新更新 - -已修改 `TwinkleBridgeInitializer._initialize_megatron()` 和 `MegatronStrategy.initialize()`, -使其能够自动检测 `TWINKLE_MODE` 环境变量,并在 Ray 模式下正确初始化分布式环境。 - -**关键修改**: -1. `bridge.py`: `_initialize_megatron()` 检测 Ray 模式并使用正确的环境变量初始化 -2. `megatron.py` (strategy): `initialize()` 同样支持 Ray 模式 -3. 
`lora.py`: 统一的 demo,通过 `--mode ray` 切换 - -## Twinkle Ray 架构解析 - -### RayHelper 环境变量传递 - -在 `ray_helper.py` 第 247-268 行,`RayHelper.create_workers` **已经正确设置**了所有必要的环境变量: - -```python -env_vars.update({ - 'WORLD_SIZE': str(world_size), - 'RANK': str(rank), - 'LOCAL_RANK': str(0), - 'MASTER_ADDR': ip, - 'MASTER_PORT': str(port), - 'TWINKLE_MODE': 'ray', - ... -}) -runtime_env = RuntimeEnv(env_vars=env_vars) -worker = worker_cls.options(runtime_env=runtime_env, ...).remote(*args, **kwargs) -``` - -**每个 Ray actor 都有正确隔离的环境变量**。 - -### 为什么之前没有工作? - -问题不在环境变量传递,而在于 **`TwinkleBridgeInitializer` 没有读取这些环境变量**。 - -之前的代码: -```python -if not dist.is_initialized(): - dist.init_process_group(backend='nccl') # 只用默认值! -``` - -修复后: -```python -if twinkle_mode == 'ray': - rank = int(os.environ.get('RANK', '0')) - world_size = int(os.environ.get('WORLD_SIZE', '1')) - master_addr = os.environ.get('MASTER_ADDR', 'localhost') - master_port = os.environ.get('MASTER_PORT', '29500') - - dist.init_process_group( - backend='nccl', - init_method=f'tcp://{master_addr}:{master_port}', - rank=rank, - world_size=world_size, - ) -``` - -## 当前实现总结 - -### MegatronWorker 类(临时方案) - -之前实现了独立的 `MegatronWorkerGroup` 和 `MegatronWorker` 类作为临时方案。 -现在已修改核心代码,理论上不再需要这个类。 - -### PP=1 梯度问题(仍存在) - -**问题描述**: -- PEFT/LoRA 包装的模型在 Megatron 的 `forward_backward_no_pipelining` 下 -- 模型输出 `logits.requires_grad=False` -- 导致 `backward_step` 被跳过 - -**根本原因**: -- Megatron 的 `backward_step` 在第 464 行检查:`if output_tensor[0].requires_grad:` -- 如果为 False,backward 被跳过 -- PEFT 模型的 forward 可能破坏了梯度追踪 - -**临时解决方案**:使用 PP > 1 - -### 为什么 Ray demo 与 Local demo 分开? - -当前有两个独立的 demo 文件: -- `lora.py`: Local 模式,使用 torchrun 启动 -- `lora_ray.py`: Ray 模式,使用 MegatronWorkerGroup - -**分开的原因**: - -1. **启动方式不同** - - Local: `torchrun --nproc_per_node=N` 启动多个进程 - - Ray: `python` 启动单个 driver,创建 Ray actors - -2. **环境变量来源不同** - - Local: torchrun 自动设置 `LOCAL_RANK`, `WORLD_SIZE` 等 - - Ray: 需要在 actor 内部手动设置 - -3. 
**分布式初始化不同** - - Local: 在脚本开头调用 `twinkle.initialize(mode='local')` - - Ray: 每个 worker 需要单独初始化 `torch.distributed` 和 `mpu` - -4. **代码入口不同** - - Local: `lora.py` 直接在 main 中调用 train() - - Ray: `lora_ray.py` 创建 worker group,然后调用 worker 方法 - ---- - -## TODO 列表 - -### 高优先级 - -1. **[ ] 修复 PP=1 时的梯度问题** - - 问题:PEFT 模型 forward 输出 `requires_grad=False` - - 可能的方案: - - a) 在 `MegatronModel.forward_backward` 中使用手动 backward - - b) 修改 PEFT 集成方式,确保梯度正确流动 - - c) 找出 Megatron/TE 哪里断开了梯度 - -2. **[ ] 统一 local 和 ray 模式的 demo** - - 目标:单一 `lora.py` 同时支持两种模式 - - 需要修改 `MegatronModel` 或 `MegatronStrategy` 来处理 Ray 环境下的初始化 - - 参考 `twinkle/cookbook/sft/lora.py` 的统一设计 - -3. **[ ] 将 MegatronWorker 逻辑集成到 Twinkle 核心架构** - - 选项 A:修改 `MegatronStrategy` 添加 Ray 模式支持 - - 选项 B:修改 `MegatronModel` 在 `@remote_class` 下正确初始化 - - 选项 C:创建 `MegatronRayStrategy` 专门处理 Ray + Megatron - -### 中优先级 - -4. **[ ] 支持 DP > 1 的 Ray 模式** - - 当前只测试了 TP+PP 组合 - - DP 需要正确的梯度同步(`finalize_model_grads`) - -5. **[ ] 支持 CP > 1 的 Ray 模式** - - Context Parallel 需要正确的序列分割和 loss 聚合 - -6. **[ ] 移除 MegatronWorkerGroup 的 hardcode** - - 当前 worker.py 中有很多硬编码逻辑 - - 应该复用 `TwinkleBridgeInitializer` 的配置 - -### 低优先级 - -7. **[ ] 添加 checkpoint 保存/加载支持** - - 当前 Ray 模式没有实现 `model.save()` - -8. **[ ] 性能优化** - - 减少 Ray object 传输开销 - - 优化 batch 分发 - -9. **[ ] 错误处理和恢复** - - Worker 失败时的处理 - - 分布式 barrier 超时处理 - ---- - -## 统一 demo 的设计方案 - -要实现 `lora.py` 同时支持 local 和 ray 模式,需要: - -```python -# 方案:在 MegatronModel/MegatronStrategy 中检测模式并初始化 - -# 1. 修改 twinkle.initialize 添加 Megatron 专用参数 -twinkle.initialize( - mode='ray', # 或 'local' - megatron_config={ - 'tp_size': 2, - 'pp_size': 2, - # ... - } -) - -# 2. MegatronModel 在 __init__ 中检测模式 -class MegatronModel: - def __init__(self, ...): - if twinkle.get_mode() == 'ray': - # 在 actor 内部初始化分布式 - self._init_ray_distributed() - else: - # 使用 torchrun 已设置的分布式 - self._init_local_distributed() - -# 3. RayHelper 在创建 actors 时设置环境变量 -# 类似当前 MegatronWorkerGroup 的做法 -``` - -**当前阻碍**: -1. 
`@remote_class` 不支持在 actor 创建时传递自定义参数(如 rank, world_size) -2. `MegatronModel.__init__` 在 driver 和 worker 中都会被调用,需要区分 - ---- - -## 测试状态 - -| 配置 | GPUs | Local Mode | Ray Mode | -|------|------|------------|----------| -| TP=2, PP=2 | 4 | ✅ | ✅ | -| TP=1, PP=4 | 4 | ✅ | ✅ | -| TP=2, PP=1 | 2 | ✅ | ❌ (梯度问题) | -| TP=1, PP=2 | 2 | ✅ | 未测试 | -| DP=2, TP=2, PP=2 | 8 | ⚠️ | 未测试 | -| CP > 1 | - | ⚠️ | 未测试 | - ---- - -## 文件清单 - -- `twinkle/src/twinkle/megatron/worker.py`: MegatronWorker 和 MegatronWorkerGroup 实现 -- `twinkle/cookbook/megatron/lora.py`: Local 模式 demo -- `twinkle/cookbook/megatron/lora_ray.py`: Ray 模式 demo(暂时保留) -- `twinkle/src/twinkle/model/megatron.py`: MegatronModel 核心实现 -- `twinkle/src/twinkle/megatron/model/bridge.py`: Megatron 模型初始化 bridge diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index bd364b3f..11dc6cd1 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -446,7 +446,8 @@ def __next__(_self): def remote_function(dispatch: Union[Literal['slice', 'all'], Callable] = 'slice', execute: Literal['first', 'peer', 'all'] = 'all', - collect: Union[Literal['none', 'flatten', 'mean', 'sum'], Callable] = 'none'): + collect: Union[Literal['none', 'flatten', 'mean', 'sum'], Callable] = 'none', + sync: bool = False): """Patch each method called from remote(which class should be decorated with `remote_class`) with this decorator. Args: @@ -462,6 +463,8 @@ def remote_function(dispatch: Union[Literal['slice', 'all'], Callable] = 'slice' 'none': Return as-is 'flatten': Return a flattened list Callable: A callable that handles the collection + sync: If True, use synchronous execution (execute_all_sync) instead of async. + Required for methods with NCCL collective operations (e.g., Megatron forward_backward). 
""" def decorator(func: Callable[..., T1]) -> Callable[..., T1]: @@ -480,16 +483,22 @@ def wrapper(self, *args, **kwargs) -> T1: args, kwargs = RayHelper.do_get_and_collect(args, kwargs) _workers_and_args = _dispatch_args(_get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) - result = RayHelper.execute_all_async(func.__name__, _workers_and_args) - result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result) - lazy_collect = _lazy_collect - if hasattr(self, '_lazy_collect'): - lazy_collect = self._lazy_collect - result = result_func if lazy_collect else result_func() - if func.__name__ == '__iter__': - return self + + # Use sync execution for methods requiring NCCL synchronization + if sync: + result = RayHelper.execute_all_sync(func.__name__, _workers_and_args) + return _collect_func(collect, result) else: - return result + result = RayHelper.execute_all_async(func.__name__, _workers_and_args) + result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result) + lazy_collect = _lazy_collect + if hasattr(self, '_lazy_collect'): + lazy_collect = self._lazy_collect + result = result_func if lazy_collect else result_func() + if func.__name__ == '__iter__': + return self + else: + return result else: raise NotImplementedError(f'Unsupported mode {_mode}') @@ -497,6 +506,7 @@ def wrapper(self, *args, **kwargs) -> T1: wrapper._collect = collect wrapper._dispatch = dispatch wrapper._lazy_collect = _lazy_collect + wrapper._sync = sync return wrapper return decorator diff --git a/src/twinkle/infra/ray/resource_manager.py b/src/twinkle/infra/ray/resource_manager.py index 47f944a7..25a2fb55 100644 --- a/src/twinkle/infra/ray/resource_manager.py +++ b/src/twinkle/infra/ray/resource_manager.py @@ -99,16 +99,12 @@ def __init__(self, self.device_groups = {} ray_address = str(ray.get_runtime_context().gcs_address) - min_rank = min(all_ranks) if all_ranks else 0 for group in groups: if group.device_type != 'CPU': ranks = 
group.ranks local_device_groups = [] for rank in ranks: - # Normalize rank by subtracting min_rank for node calculation - normalized_rank = rank - min_rank - node_rank = normalized_rank // nproc_per_node - # Use original rank for gpu_rank to support non-zero starting ranks + node_rank = rank // nproc_per_node gpu_rank = rank % nproc_per_node local_device_groups.append( dict( diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index f3d8c387..c501225e 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -11,7 +11,7 @@ from .listwise_reranker import ListwiseRerankerLoss from .listwise_generative_reranker import ListwiseGenerativeRerankerLoss from .grpo import GRPOLoss -from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss, MegatronCrossEntropyLoss +from .vocab_parallel_cross_entropy import MegatronCrossEntropyLoss from .base import Loss torch_loss_mapping = { @@ -27,6 +27,5 @@ 'listwise_reranker': ListwiseRerankerLoss, 'listwise_generative_reranker': ListwiseGenerativeRerankerLoss, 'grpo': GRPOLoss, - 'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss, 'megatron_cross_entropy': MegatronCrossEntropyLoss, -} \ No newline at end of file +} diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py index 19deaf37..1c035c4f 100644 --- a/src/twinkle/loss/vocab_parallel_cross_entropy.py +++ b/src/twinkle/loss/vocab_parallel_cross_entropy.py @@ -43,5 +43,4 @@ def __call__(self, inputs, outputs, **kwargs): return loss -# Alias for backward compatibility MegatronCrossEntropyLoss = VocabParallelCrossEntropyLoss diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index 3a2ff7bd..df71d44c 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -1,15 +1,5 @@ # Copyright (c) twinkle authors. All rights reserved. # GPT Bridge for HuggingFace to Megatron-Core weight conversion. 
-"""Weight conversion bridge between HuggingFace and Megatron-Core formats. - -This module provides independent implementation for weight loading/saving, - -Supports: -- Qwen2.5 / Qwen3 model families -- PEFT/LoRA format loading and saving -- Tensor Parallel / Pipeline Parallel weight sharding -- MoE (Mixture of Experts) models -""" from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from types import SimpleNamespace from dataclasses import dataclass, field diff --git a/src/twinkle/megatron/worker.py b/src/twinkle/megatron/worker.py deleted file mode 100644 index 1afe7818..00000000 --- a/src/twinkle/megatron/worker.py +++ /dev/null @@ -1,368 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -"""Megatron Worker for Ray-based distributed training. - -This module provides MegatronWorkerGroup for coordinated Ray actor-based -training with Megatron's collective operations. - -NOTE: Currently PP > 1 is required for Ray mode training with LoRA. -PP=1 has gradient flow issues that need further investigation. 
- -Example: - worker_group = MegatronWorkerGroup(world_size=4, tp_size=2, pp_size=2) - worker_group.init_all() - worker_group.create_model_all('Qwen/Qwen2.5-0.5B-Instruct') - worker_group.add_lora_all({'target_modules': ['linear_qkv'], 'r': 8}) - worker_group.set_optimizer_all(lr=1e-4) - - for batch in dataloader: - losses = worker_group.forward_backward_all(batch) - worker_group.step_all() -""" -import os -from typing import Any, Dict, List - -import torch - - -def get_megatron_worker_class(): - """Returns a Ray remote class for Megatron workers.""" - import ray - - @ray.remote(num_gpus=1) - class MegatronWorker: - """Ray actor for a single Megatron rank.""" - - def __init__( - self, - rank: int, - world_size: int, - master_addr: str, - master_port: int, - tp_size: int = 1, - pp_size: int = 1, - cp_size: int = 1, - ep_size: int = 1, - ): - self.rank = rank - self.world_size = world_size - self.master_addr = master_addr - self.master_port = master_port - self.tp_size = tp_size - self.pp_size = pp_size - self.cp_size = cp_size - self.ep_size = ep_size - self.model = None - self.optimizer = None - self.hf_config = None - - def _get_local_gpu_id(self) -> int: - """Get local GPU ID for this actor.""" - import ray - cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) - if cvd is None: - gpu_ids = ray.get_gpu_ids() - return int(gpu_ids[0]) if gpu_ids else 0 - else: - gpu_ids = ray.get_gpu_ids() - if gpu_ids: - return cvd.split(",").index(str(int(gpu_ids[0]))) - return 0 - - def init(self, model_config: Dict[str, Any] = None) -> bool: - """Initialize distributed and Megatron parallel state.""" - import torch.distributed as dist - from datetime import timedelta - - os.environ["MASTER_ADDR"] = self.master_addr - os.environ["MASTER_PORT"] = str(self.master_port) - os.environ["WORLD_SIZE"] = str(self.world_size) - os.environ["RANK"] = str(self.rank) - - local_rank = self._get_local_gpu_id() - os.environ["LOCAL_RANK"] = str(local_rank) - torch.cuda.set_device(local_rank) - - if 
not dist.is_initialized(): - dist.init_process_group(backend="nccl", timeout=timedelta(minutes=10)) - - from megatron.core import parallel_state as mpu - if not mpu.is_initialized(): - mpu.initialize_model_parallel( - tensor_model_parallel_size=self.tp_size, - pipeline_model_parallel_size=self.pp_size, - context_parallel_size=self.cp_size, - expert_model_parallel_size=self.ep_size, - ) - - from megatron.core import tensor_parallel - torch.manual_seed(42 + self.rank) - tensor_parallel.model_parallel_cuda_manual_seed(42 + self.rank) - - print(f"[Worker rank={self.rank}] Initialized TP={self.tp_size} PP={self.pp_size}") - return True - - def create_model( - self, - pretrained_model_name_or_path: str, - mixed_precision: str = 'bf16', - recompute_granularity: str = 'full', - **kwargs, - ) -> bool: - """Create Megatron model.""" - from twinkle.megatron.model.bridge import TwinkleBridgeInitializer - - dtype_map = {'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16} - params_dtype = dtype_map.get(mixed_precision, torch.bfloat16) - - initializer = TwinkleBridgeInitializer( - tp_size=self.tp_size, - pp_size=self.pp_size, - cp_size=self.cp_size, - ep_size=self.ep_size, - params_dtype=params_dtype, - recompute_granularity=recompute_granularity, - **kwargs, - ) - - self.model = initializer.create_model(pretrained_model_name_or_path) - self.hf_config = initializer._hf_config - print(f"[Worker rank={self.rank}] Model created") - return True - - def add_lora(self, lora_config: Dict[str, Any]) -> bool: - """Add LoRA adapter.""" - from peft import get_peft_model, LoraConfig - from peft.tuners.tuners_utils import BaseTuner - import torch.nn as nn - - # Patch for Megatron's TransformerConfig - orig_fn = BaseTuner._get_tied_target_modules - def patched_fn(self, model: nn.Module): - try: - return orig_fn(self, model) - except AttributeError: - return [] - BaseTuner._get_tied_target_modules = patched_fn - - from twinkle.megatron.utils import set_linear_is_expert - 
set_linear_is_expert(self.model) - - config = LoraConfig(**lora_config) - self.model = get_peft_model(self.model, config) - - # Add compatibility methods for Megatron DDP - if not hasattr(self.model, 'finish_grad_sync'): - self.model.finish_grad_sync = lambda: None - if not hasattr(self.model, 'start_grad_sync'): - self.model.start_grad_sync = lambda: None - if not hasattr(self.model, 'no_sync'): - from contextlib import nullcontext - self.model.no_sync = nullcontext - - # Create a dummy ddp_config that has necessary attributes - if not hasattr(self.model, 'ddp_config') or self.model.ddp_config is None: - class DummyDDPConfig: - use_megatron_fsdp = False - use_distributed_optimizer = False - overlap_grad_reduce = False - overlap_param_gather = False - bucket_size = None - self.model.ddp_config = DummyDDPConfig() - - trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad) - print(f"[Worker rank={self.rank}] LoRA added, trainable params={trainable}") - return True - - def set_optimizer(self, lr: float = 1e-4, **kwargs) -> bool: - """Set up optimizer.""" - from torch.optim import AdamW - trainable_params = [p for p in self.model.parameters() if p.requires_grad] - self.optimizer = AdamW(trainable_params, lr=lr, **kwargs) - print(f"[Worker rank={self.rank}] Optimizer set") - return True - - def forward_backward(self, batch: Dict[str, torch.Tensor]) -> float: - """Execute forward-backward pass.""" - from functools import partial - from megatron.core.pipeline_parallel import get_forward_backward_func - - local_rank = self._get_local_gpu_id() - batch = {k: v.cuda(local_rank) if isinstance(v, torch.Tensor) else v - for k, v in batch.items()} - - seq_length = batch['input_ids'].shape[1] - micro_batch_size = batch['input_ids'].shape[0] - - def forward_step_func(data_iterator, model): - batch = next(data_iterator) - input_ids = batch['input_ids'] - labels = batch.get('labels') - attention_mask = batch.get('attention_mask') - - position_ids = 
torch.arange( - input_ids.shape[1], device=input_ids.device, dtype=torch.long - ).unsqueeze(0).expand(input_ids.shape[0], -1) - - output = model( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - ) - - def loss_func(labels, output): - mask = (labels != -100).float() - loss = (output.float().view(-1) * mask.view(-1)).sum() / mask.sum().clamp(min=1) - return loss, {'loss': loss.detach()} - - return output, partial(loss_func, labels) - - self.model.train() - forward_backward_func = get_forward_backward_func() - - losses = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=iter([batch]), - model=[self.model], - num_microbatches=1, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - forward_only=False, - ) - - if losses and isinstance(losses[0], dict) and 'loss' in losses[0]: - return losses[0]['loss'].item() - return 0.0 - - def step(self) -> bool: - """Optimizer step.""" - if self.optimizer is None: - return False - torch.nn.utils.clip_grad_norm_( - [p for p in self.model.parameters() if p.requires_grad], 1.0 - ) - self.optimizer.step() - self.optimizer.zero_grad() - return True - - def cleanup(self) -> bool: - """Clean up resources.""" - import torch.distributed as dist - from megatron.core import parallel_state as mpu - try: - if dist.is_initialized(): - dist.barrier() - if mpu.is_initialized(): - mpu.destroy_model_parallel() - if dist.is_initialized(): - dist.destroy_process_group() - except Exception as e: - print(f"[Worker rank={self.rank}] Cleanup error: {e}") - return True - - return MegatronWorker - - -class MegatronWorkerGroup: - """Manager for coordinated Megatron Ray workers. - - Handles synchronized creation, initialization, and execution - of Megatron workers for distributed training. - - NOTE: PP > 1 is required for training with LoRA. PP=1 has gradient issues. 
- """ - - def __init__( - self, - world_size: int, - tp_size: int = 1, - pp_size: int = 1, - cp_size: int = 1, - ep_size: int = 1, - master_addr: str = None, - master_port: int = None, - ): - import ray - import socket - - # Warn if PP=1 (known gradient issue) - if pp_size == 1: - print("[MegatronWorkerGroup] WARNING: PP=1 has known gradient issues. " - "Training loss may not decrease. Use PP > 1 for training.") - - self.world_size = world_size - self.tp_size = tp_size - self.pp_size = pp_size - self.cp_size = cp_size - self.ep_size = ep_size - - if master_addr is None: - master_addr = ray.util.get_node_ip_address() - if master_port is None: - with socket.socket() as sock: - sock.bind(("", 0)) - master_port = sock.getsockname()[1] - - self.master_addr = master_addr - self.master_port = master_port - - MegatronWorker = get_megatron_worker_class() - self.workers = [ - MegatronWorker.remote( - rank=rank, - world_size=world_size, - master_addr=master_addr, - master_port=master_port, - tp_size=tp_size, - pp_size=pp_size, - cp_size=cp_size, - ep_size=ep_size, - ) - for rank in range(world_size) - ] - print(f"[MegatronWorkerGroup] Created {world_size} workers (TP={tp_size}, PP={pp_size})") - - def init_all(self, model_config: Dict[str, Any] = None) -> List[bool]: - """Initialize all workers.""" - import ray - return ray.get([w.init.remote(model_config) for w in self.workers]) - - def create_model_all(self, pretrained_model_name_or_path: str, **kwargs) -> List[bool]: - """Create model on all workers.""" - import ray - return ray.get([w.create_model.remote(pretrained_model_name_or_path, **kwargs) for w in self.workers]) - - def add_lora_all(self, lora_config: Dict[str, Any]) -> List[bool]: - """Add LoRA to all workers.""" - import ray - return ray.get([w.add_lora.remote(lora_config) for w in self.workers]) - - def set_optimizer_all(self, lr: float = 1e-4, **kwargs) -> List[bool]: - """Set optimizer on all workers.""" - import ray - return ray.get([w.set_optimizer.remote(lr, 
**kwargs) for w in self.workers]) - - def forward_backward_all(self, batch: Dict[str, torch.Tensor]) -> List[float]: - """Execute forward/backward on all workers.""" - import ray - return ray.get([w.forward_backward.remote(batch) for w in self.workers]) - - def step_all(self) -> List[bool]: - """Optimizer step on all workers.""" - import ray - return ray.get([w.step.remote() for w in self.workers]) - - def cleanup_all(self) -> List[bool]: - """Cleanup all workers.""" - import ray - return ray.get([w.cleanup.remote() for w in self.workers]) - - def shutdown(self): - """Shutdown all workers.""" - import ray - for worker in self.workers: - try: - ray.kill(worker) - except Exception: - pass - self.workers = [] diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index d342e4d4..281256b4 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -16,7 +16,7 @@ from twinkle import remote_class, remote_function, template, DeviceMesh from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation -from twinkle.loss import Loss, VocabParallelCrossEntropyLoss +from twinkle.loss import Loss, MegatronCrossEntropyLoss from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils.plugin import Plugin @@ -151,10 +151,9 @@ def __init__( self.model = self._create_megatron_model(model_path, load_weights, **kwargs) self._model_wrapped = False - # Use VocabParallelCrossEntropyLoss by default for Megatron # This correctly handles vocab sharding in Tensor Parallelism self.optimizer_group: Dict[str, MegatronOptimizerGroup] = { - _default_adapter_name: MegatronOptimizerGroup(loss_instance=VocabParallelCrossEntropyLoss()) + _default_adapter_name: MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) } def _load_hf_config(self, model_path: str): @@ -512,10 +511,14 @@ def backward(self, **kwargs): loss_value.backward() optimizer_config.cur_step += 1 - 
@remote_function(dispatch='all', collect='avg') + @remote_function(dispatch='all', collect='avg', sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Combined forward and backward pass using Megatron's scheduler. + Note: sync=True is required for Ray mode because Megatron's pipeline + parallel uses NCCL P2P communication that requires all ranks to enter + the function simultaneously. + Always uses Megatron's get_forward_backward_func() which handles: - Pipeline scheduling (1F1B, interleaved, or no-pipeline) - Communication between stages (using proper process groups for multi-tenant isolation) @@ -583,13 +586,12 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr labels = labels.to(torch.cuda.current_device()) def split_tensor_for_cp(tensor, dim=-1): - """Split tensor along sequence dimension for Context Parallel. + """ + Split tensor along sequence dimension for Context Parallel. With causal masking, split into 2*CP chunks and assign alternating chunks to balance workload across CP ranks. For CP rank i: chunks [i, 2*CP-1-i] - - Based on Swift's split_cp_inputs implementation. """ if tensor is None or cp_size <= 1: return tensor @@ -901,7 +903,7 @@ def lr_step(self, **kwargs): if lr_scheduler is not None: lr_scheduler.step(**kwargs) - @remote_function() + @remote_function(dispatch='all') def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): """Set loss function. @@ -928,7 +930,7 @@ def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): # Keep for API compatibility, but not used in forward_backward optimizer_config.loss_instance = loss_cls() - @remote_function() + @remote_function(dispatch='all') def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs): """Set optimizer. 
@@ -968,7 +970,7 @@ def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) - params[name] = param return params - @remote_function() + @remote_function(dispatch='all') def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwargs): """Set learning rate scheduler. @@ -989,7 +991,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwarg assert optimizer is not None, 'Set optimizer before setting lr_scheduler' optimizer_config.lr_scheduler = scheduler_cls(optimizer, **kwargs) - @remote_function() + @remote_function(dispatch='all', sync=True) def save(self, output_dir: str, **kwargs): """Save model checkpoint. @@ -1140,7 +1142,7 @@ def _get_tied_target_modules(self, model: nn.Module) -> List[str]: BaseTuner._get_tied_target_modules = _get_tied_target_modules cls._peft_patched = True - @remote_function() + @remote_function(dispatch='all', sync=True) def add_adapter_to_model( self, adapter_name: str, @@ -1250,7 +1252,7 @@ def finish_grad_sync(): if default_config.loss_instance: self.optimizer_group[adapter_name].loss_instance = default_config.loss_instance - @remote_function() + @remote_function(dispatch='all') def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs): """Set template for input encoding. @@ -1268,7 +1270,7 @@ def set_template(self, template_cls: Union[Type[template.Template], str], **kwar template_cls = Plugin.load_plugin(template_cls, template.Template) optimizer_config.template = template_cls(self.model_id, **kwargs) - @remote_function() + @remote_function(dispatch='all') def set_processor(self, processor_cls: Union[Type[InputProcessor], str], **kwargs): """Set input processor. 
From 9085350dbffe8e3b8f549f180bd52b156ea9dcb6 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 15 Jan 2026 15:47:22 +0800 Subject: [PATCH 06/22] ep and lora ddp --- .locks/dataset.lock | 0 ...lscope_competition_math_default_train.lock | 0 cookbook/megatron/client.py | 287 ---------- cookbook/megatron/moe_lora.py | 208 +++++++ cookbook/megatron/server.py | 270 --------- src/twinkle/megatron/distributed/__init__.py | 12 + src/twinkle/megatron/distributed/lora_ddp.py | 531 ++++++++++++++++++ src/twinkle/megatron/model/bridge.py | 283 ++++++++-- src/twinkle/model/megatron.py | 506 +++++++++++++---- test_ray_configs.py | 174 ++++++ 10 files changed, 1560 insertions(+), 711 deletions(-) delete mode 100644 .locks/dataset.lock delete mode 100644 .locks/ms___modelscope_competition_math_default_train.lock delete mode 100644 cookbook/megatron/client.py create mode 100644 cookbook/megatron/moe_lora.py delete mode 100644 cookbook/megatron/server.py create mode 100644 src/twinkle/megatron/distributed/__init__.py create mode 100644 src/twinkle/megatron/distributed/lora_ddp.py create mode 100644 test_ray_configs.py diff --git a/.locks/dataset.lock b/.locks/dataset.lock deleted file mode 100644 index e69de29b..00000000 diff --git a/.locks/ms___modelscope_competition_math_default_train.lock b/.locks/ms___modelscope_competition_math_default_train.lock deleted file mode 100644 index e69de29b..00000000 diff --git a/cookbook/megatron/client.py b/cookbook/megatron/client.py deleted file mode 100644 index 862d11e3..00000000 --- a/cookbook/megatron/client.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -"""Megatron LoRA training client. - -This client sends training requests to the Megatron model server. 
- -Usage: - # First start the server: - python cookbook/megatron/server.py --port 8000 - - # Then run the client: - python cookbook/megatron/client.py --server_url http://localhost:8000 -""" -import argparse -from typing import Any, Dict, Optional - -import requests - -from twinkle import get_logger -from twinkle.dataset import Dataset, DatasetMeta - -logger = get_logger() - - -class MegatronModelClient: - """Client for remote Megatron model training.""" - - def __init__(self, server_url: str, timeout: int = 300): - """Initialize client. - - Args: - server_url: URL of the model server. - timeout: Request timeout in seconds. - """ - self.server_url = server_url.rstrip('/') - self.timeout = timeout - - def _request(self, endpoint: str, method: str = 'POST', data: Dict = None) -> Dict: - """Send request to server. - - Args: - endpoint: API endpoint. - method: HTTP method. - data: Request data. - - Returns: - Response data. - """ - url = f'{self.server_url}/{endpoint}' - - try: - if method == 'GET': - response = requests.get(url, timeout=self.timeout) - else: - response = requests.post(url, json=data or {}, timeout=self.timeout) - - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - logger.error(f'Request failed: {e}') - return {'status': 'error', 'message': str(e)} - - def health_check(self) -> bool: - """Check if server is healthy. - - Returns: - True if server is healthy. - """ - result = self._request('health', method='GET') - return result.get('status') == 'healthy' - - def initialize_model( - self, - model_name: str, - lora_config: Optional[Dict[str, Any]] = None, - ) -> Dict: - """Initialize model on server. - - Args: - model_name: HuggingFace model name or path. - lora_config: Optional LoRA configuration. - - Returns: - Server response. 
- """ - return self._request('initialize', data={ - 'model_name': model_name, - 'lora_config': lora_config, - }) - - def set_optimizer(self, optimizer_type: str = 'AdamW', **kwargs) -> Dict: - """Set optimizer on server. - - Args: - optimizer_type: Optimizer type name. - **kwargs: Optimizer arguments. - - Returns: - Server response. - """ - return self._request('set_optimizer', data={ - 'optimizer_type': optimizer_type, - **kwargs, - }) - - def set_lr_scheduler(self, scheduler_type: str = 'CosineAnnealingLR', **kwargs) -> Dict: - """Set learning rate scheduler on server. - - Args: - scheduler_type: Scheduler type name. - **kwargs: Scheduler arguments. - - Returns: - Server response. - """ - return self._request('set_lr_scheduler', data={ - 'scheduler_type': scheduler_type, - **kwargs, - }) - - def train_step(self, batch: Dict[str, Any]) -> Dict: - """Execute one training step. - - Args: - batch: Input batch data. - - Returns: - Server response with loss. - """ - return self._request('train_step', data={'batch': batch}) - - def save_checkpoint(self, output_path: str) -> Dict: - """Save model checkpoint. - - Args: - output_path: Path to save checkpoint. - - Returns: - Server response. - """ - return self._request('save', data={'output_path': output_path}) - - def get_train_configs(self) -> Dict: - """Get training configuration from server. - - Returns: - Training configuration. 
- """ - return self._request('configs', method='GET') - - -def create_dataset(args): - """Create and preprocess dataset.""" - dataset = Dataset(dataset_meta=DatasetMeta(args.dataset)) - dataset.set_template('Qwen3Template', model_id=args.model_name) - dataset.map('CompetitionMathProcessor') - dataset.encode(batched=True) - return dataset - - -def parse_args(): - parser = argparse.ArgumentParser(description='Megatron Model Client') - - # Server arguments - parser.add_argument('--server_url', type=str, default='http://localhost:8000', - help='Model server URL') - parser.add_argument('--timeout', type=int, default=300, - help='Request timeout in seconds') - - # Model arguments - parser.add_argument('--model_name', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct', - help='HuggingFace model name or path') - parser.add_argument('--output_dir', type=str, default='./output/megatron_lora', - help='Output directory for checkpoints') - - # LoRA arguments - parser.add_argument('--lora_rank', type=int, default=8, - help='LoRA rank') - parser.add_argument('--lora_alpha', type=int, default=32, - help='LoRA alpha') - parser.add_argument('--lora_dropout', type=float, default=0.05, - help='LoRA dropout') - parser.add_argument('--target_modules', type=str, default='all-linear', - help='Target modules for LoRA') - - # Training arguments - parser.add_argument('--batch_size', type=int, default=4, - help='Batch size') - parser.add_argument('--learning_rate', type=float, default=1e-4, - help='Learning rate') - parser.add_argument('--max_steps', type=int, default=1000, - help='Maximum training steps') - parser.add_argument('--save_steps', type=int, default=50, - help='Checkpoint save interval') - parser.add_argument('--log_steps', type=int, default=10, - help='Logging interval') - - # Dataset arguments - parser.add_argument('--dataset', type=str, default='ms://modelscope/competition_math', - help='Dataset name') - - return parser.parse_args() - - -def main(): - args = parse_args() - - # 
Create client - client = MegatronModelClient( - server_url=args.server_url, - timeout=args.timeout, - ) - - # Health check - if not client.health_check(): - logger.error('Server is not available') - return - - logger.info('Server is healthy, initializing model...') - - # Initialize model with LoRA - lora_config = { - 'r': args.lora_rank, - 'lora_alpha': args.lora_alpha, - 'lora_dropout': args.lora_dropout, - 'target_modules': args.target_modules, - } - - result = client.initialize_model( - model_name=args.model_name, - lora_config=lora_config, - ) - - if result.get('status') != 'success': - logger.error(f'Failed to initialize model: {result}') - return - - logger.info('Model initialized, setting optimizer...') - - # Set optimizer and scheduler - client.set_optimizer(optimizer_type='AdamW', lr=args.learning_rate, weight_decay=0.01) - client.set_lr_scheduler(scheduler_type='CosineAnnealingLR', T_max=args.max_steps) - - # Print training configuration - configs = client.get_train_configs() - logger.info(f'Training configs: {configs}') - - # Create dataset and dataloader - logger.info('Loading dataset...') - dataset = create_dataset(args) - - # Training loop - logger.info('Starting training...') - global_step = 0 - - for step, batch in enumerate(dataset.iter(batch_size=args.batch_size)): - if global_step >= args.max_steps: - break - - # Send batch to server for training - result = client.train_step(batch) - - if result.get('status') != 'success': - logger.error(f'Training step failed: {result}') - continue - - global_step += 1 - - # Log progress - if global_step % args.log_steps == 0: - loss = result.get('loss', 'N/A') - logger.info(f'Step {global_step}, Loss: {loss}') - - # Save checkpoint - if global_step % args.save_steps == 0: - save_result = client.save_checkpoint( - f'{args.output_dir}/checkpoint-{global_step}' - ) - logger.info(f'Checkpoint saved: {save_result}') - - # Save final model - client.save_checkpoint(args.output_dir) - logger.info(f'Training completed. 
Model saved to {args.output_dir}') - - -if __name__ == '__main__': - main() - diff --git a/cookbook/megatron/moe_lora.py b/cookbook/megatron/moe_lora.py new file mode 100644 index 00000000..37ee3894 --- /dev/null +++ b/cookbook/megatron/moe_lora.py @@ -0,0 +1,208 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core MoE (Mixture of Experts) LoRA training example. + +Supports Expert Parallel (EP) training in both local (torchrun) and Ray modes. + +Usage (Local mode with EP=2): + torchrun --nproc_per_node=4 cookbook/megatron/moe_lora.py --tp_size 2 --pp_size 1 --ep_size 2 + +Usage (Ray mode with EP=2): + TRUST_REMOTE_CODE=1 python cookbook/megatron/moe_lora.py --mode ray --tp_size 2 --pp_size 1 --ep_size 2 --num_gpus 4 +""" +import argparse +import os + +# Parse arguments first to determine mode +parser = argparse.ArgumentParser() +parser.add_argument('--mode', type=str, default='local', choices=['local', 'ray']) +parser.add_argument('--tp_size', type=int, default=2) +parser.add_argument('--pp_size', type=int, default=1) +parser.add_argument('--cp_size', type=int, default=1) +parser.add_argument('--ep_size', type=int, default=2, help='Expert parallel size') +parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs (Ray mode only)') +parser.add_argument('--max_steps', type=int, default=5) +parser.add_argument('--model', type=str, default='ms://Qwen/Qwen3-30B-A3B', + help='MoE model path. 
Default: Qwen3-30B-A3B (128 experts)') +parser.add_argument('--sequence_parallel', action='store_true', default=False, + help='Enable sequence parallel (auto-enabled for MoE with TP > 1)') +args = parser.parse_args() + +# Set mode in environment before importing twinkle +os.environ['TWINKLE_MODE'] = args.mode + +# CRITICAL: Set CUDA device before any CUDA imports (local mode only) +import torch +if args.mode == 'local': + LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) + torch.cuda.set_device(LOCAL_RANK) + +import numpy as np +from peft import LoraConfig +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR + +import twinkle +from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import MegatronCrossEntropyLoss +from twinkle.model import MegatronModel +from twinkle.processor import InputProcessor + +logger = get_logger() + + +def create_dataset(): + """Create dataset for MoE training.""" + dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + # Use Qwen3 template for MoE model + dataset.set_template('Qwen3Template', model_id=args.model) + dataset.map('CompetitionMathProcessor') + dataset.encode(batched=True, load_from_cache_file=False) + return dataset + + +def train(): + # Get parallelism config + TP_SIZE = args.tp_size + PP_SIZE = args.pp_size + CP_SIZE = args.cp_size + EP_SIZE = args.ep_size + + if args.mode == 'local': + WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) + else: + WORLD_SIZE = args.num_gpus + + # For MoE with EP: Total parallelism = TP * PP * CP * EP * DP + # EP is placed between CP and DP in Megatron's order + DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE) + + if DP_SIZE < 1: + raise ValueError( + f"Not enough GPUs ({WORLD_SIZE}) for parallelism config: " + f"TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}. 
" + f"Required: {TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE}" + ) + + logger.info(f"Parallelism config: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}, DP={DP_SIZE}") + + # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost + # Shape: (PP, DP, EP, CP, TP) + device_mesh = DeviceMesh( + device_type='cuda', + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, EP_SIZE, CP_SIZE, TP_SIZE), + mesh_dim_names=('pp', 'dp', 'ep', 'cp', 'tp'), + ) + + # Device group name - used as remote_group in Ray mode + GROUP_NAME = 'model' + + device_group = [ + DeviceGroup( + name=GROUP_NAME, + ranks=list(range(WORLD_SIZE)), + device_type=Platform.get_platform().device_prefix(), + ) + ] + + twinkle.initialize( + mode=args.mode, + nproc_per_node=WORLD_SIZE, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, + ) + + # Smaller batch size for MoE models (larger memory footprint) + batch_size = 2 + + # In Ray mode, pass remote_group and device_mesh + if args.mode == 'ray': + dataloader = DataLoader( + dataset=create_dataset, + batch_size=batch_size, + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + expert_model_parallel_size=EP_SIZE, + sequence_parallel=args.sequence_parallel, + mixed_precision='bf16', + recompute_granularity='selective', + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + else: + dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + expert_model_parallel_size=EP_SIZE, + sequence_parallel=args.sequence_parallel, + mixed_precision='bf16', + recompute_granularity='selective', + ) + + # LoRA config - target all linear layers in 
MoE (including experts) + lora_config = LoraConfig( + target_modules='all-linear', + r=8, + lora_alpha=8, + lora_dropout=0.0, + ) + adapter_name = 'lora' + model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=16) + model.set_template('Qwen3Template', adapter_name=adapter_name) + model.set_processor(InputProcessor, padding_side='right', adapter_name=adapter_name) + model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name) + model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) + model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs(adapter_name=adapter_name)) + + for step, batch in enumerate(dataloader): + output = model.forward_backward(inputs=batch, adapter_name=adapter_name) + if step % 16 == 0: + logger.info(f'Step {step // 16}, loss: {output}') + model.clip_grad_norm(1.0, adapter_name=adapter_name) + model.step(adapter_name=adapter_name) + model.zero_grad(adapter_name=adapter_name) + model.lr_step(adapter_name=adapter_name) + if step % 100 == 0: + model.save('./output/megatron_moe_lora', adapter_name=adapter_name) + # Early stop for testing + if args.max_steps and step >= args.max_steps * 16: + logger.info(f'Reached max_steps ({args.max_steps}), stopping.') + break + + logger.info('Training completed!') + + +def cleanup(): + """Clean up distributed resources.""" + import torch.distributed as dist + try: + if dist.is_initialized(): + dist.barrier() + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + mpu.destroy_model_parallel() + except Exception: + pass + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == '__main__': + try: + train() + finally: + cleanup() diff --git a/cookbook/megatron/server.py b/cookbook/megatron/server.py deleted file mode 100644 index 71255028..00000000 --- a/cookbook/megatron/server.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) twinkle authors. 
All rights reserved. -"""Megatron LoRA training server. - -This server hosts the Megatron model and handles training requests from clients. - -Usage: - python cookbook/megatron/server.py --port 8000 --tp_size 2 -""" -import argparse -from typing import Any, Dict - -import numpy as np - -import twinkle -from twinkle import get_logger, DeviceMesh, DeviceGroup, Platform -from twinkle.model import MegatronModel -from twinkle.loss import CrossEntropyLoss -from twinkle.processor import InputProcessor - -logger = get_logger() - - -class MegatronModelServer: - """Server wrapper for Megatron model.""" - - def __init__(self, args): - self.args = args - self.model = None - self.is_initialized = False - - def initialize_model(self, model_name: str, lora_config: Dict[str, Any] = None): - """Initialize the Megatron model with optional LoRA configuration. - - Args: - model_name: HuggingFace model name or path. - lora_config: Optional LoRA configuration dict. - """ - logger.info(f'Initializing model: {model_name}') - - self.model = MegatronModel( - pretrained_model_name_or_path=model_name, - tensor_model_parallel_size=self.args.tp_size, - sequence_parallel=self.args.sequence_parallel, - mixed_precision=self.args.mixed_precision, - ) - - if lora_config: - from peft import LoraConfig - config = LoraConfig(**lora_config) - self.model.add_adapter_to_model( - 'default', - config, - gradient_accumulation_steps=self.args.gradient_accumulation_steps, - ) - - self.model.set_template('Qwen3Template') - self.model.set_processor(InputProcessor, padding_side='right') - self.model.set_loss(CrossEntropyLoss) - - self.is_initialized = True - logger.info('Model initialized successfully') - - return {'status': 'success', 'message': 'Model initialized'} - - def set_optimizer(self, optimizer_type: str = 'AdamW', **kwargs): - """Set optimizer for the model.""" - if not self.is_initialized: - return {'status': 'error', 'message': 'Model not initialized'} - - from torch.optim import AdamW, SGD - 
optimizer_map = {'AdamW': AdamW, 'SGD': SGD} - - if optimizer_type not in optimizer_map: - return {'status': 'error', 'message': f'Unknown optimizer: {optimizer_type}'} - - self.model.set_optimizer(optimizer_map[optimizer_type], **kwargs) - return {'status': 'success', 'message': f'Optimizer {optimizer_type} set'} - - def set_lr_scheduler(self, scheduler_type: str = 'CosineAnnealingLR', **kwargs): - """Set learning rate scheduler.""" - if not self.is_initialized: - return {'status': 'error', 'message': 'Model not initialized'} - - from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, StepLR - scheduler_map = { - 'CosineAnnealingLR': CosineAnnealingLR, - 'LinearLR': LinearLR, - 'StepLR': StepLR, - } - - if scheduler_type not in scheduler_map: - return {'status': 'error', 'message': f'Unknown scheduler: {scheduler_type}'} - - self.model.set_lr_scheduler(scheduler_map[scheduler_type], **kwargs) - return {'status': 'success', 'message': f'Scheduler {scheduler_type} set'} - - def train_step(self, batch: Dict[str, Any]): - """Execute one training step. - - Args: - batch: Input batch data. - - Returns: - Training step result with loss. - """ - if not self.is_initialized: - return {'status': 'error', 'message': 'Model not initialized'} - - # Forward-backward pass - loss = self.model.forward_backward(inputs=batch) - - # Optimizer step - self.model.clip_grad_norm(self.args.max_grad_norm) - self.model.step() - self.model.zero_grad() - self.model.lr_step() - - return {'status': 'success', 'loss': float(loss) if loss else None} - - def save_checkpoint(self, output_path: str): - """Save model checkpoint. - - Args: - output_path: Path to save checkpoint. 
- """ - if not self.is_initialized: - return {'status': 'error', 'message': 'Model not initialized'} - - self.model.save(output_path) - return {'status': 'success', 'message': f'Checkpoint saved to {output_path}'} - - def get_train_configs(self): - """Get current training configuration.""" - if not self.is_initialized: - return {'status': 'error', 'message': 'Model not initialized'} - - return {'status': 'success', 'configs': self.model.get_train_configs()} - - -def create_device_mesh(args) -> DeviceMesh: - """Create device mesh for Megatron parallelism.""" - mesh = np.arange(args.nproc_per_node).reshape(args.dp_size, args.tp_size) - - device_mesh = DeviceMesh( - device_type='cuda', - mesh=mesh, - mesh_dim_names=('dp', 'tp'), - ) - return device_mesh - - -def create_device_group(args): - """Create device group for model placement.""" - device_group = [ - DeviceGroup( - name='model', - ranks=list(range(args.nproc_per_node)), - device_type=Platform.get_platform().device_prefix(), - ) - ] - return device_group - - -def parse_args(): - parser = argparse.ArgumentParser(description='Megatron Model Server') - - # Server arguments - parser.add_argument('--host', type=str, default='0.0.0.0', - help='Server host') - parser.add_argument('--port', type=int, default=8000, - help='Server port') - - # Parallelism arguments - parser.add_argument('--nproc_per_node', type=int, default=4, - help='Number of processes per node') - parser.add_argument('--tp_size', type=int, default=2, - help='Tensor parallel size') - parser.add_argument('--dp_size', type=int, default=2, - help='Data parallel size') - parser.add_argument('--sequence_parallel', action='store_true', - help='Enable sequence parallelism') - parser.add_argument('--mixed_precision', type=str, default='bf16', - choices=['no', 'fp16', 'bf16'], - help='Mixed precision mode') - - # Training defaults - parser.add_argument('--gradient_accumulation_steps', type=int, default=16, - help='Gradient accumulation steps') - 
parser.add_argument('--max_grad_norm', type=float, default=1.0, - help='Maximum gradient norm for clipping') - - return parser.parse_args() - - -def main(): - args = parse_args() - - # Initialize distributed environment - device_mesh = create_device_mesh(args) - device_group = create_device_group(args) - - twinkle.initialize( - mode='local', - nproc_per_node=args.nproc_per_node, - groups=device_group, - global_device_mesh=device_mesh, - lazy_collect=False, - ) - - # Create model server - server = MegatronModelServer(args) - - # Start HTTP server - try: - from flask import Flask, request, jsonify - except ImportError: - logger.error('Flask not installed. Install with: pip install flask') - return - - app = Flask(__name__) - - @app.route('/health', methods=['GET']) - def health(): - return jsonify({'status': 'healthy'}) - - @app.route('/initialize', methods=['POST']) - def initialize(): - data = request.json - result = server.initialize_model( - model_name=data.get('model_name'), - lora_config=data.get('lora_config'), - ) - return jsonify(result) - - @app.route('/set_optimizer', methods=['POST']) - def set_optimizer(): - data = request.json - result = server.set_optimizer(**data) - return jsonify(result) - - @app.route('/set_lr_scheduler', methods=['POST']) - def set_lr_scheduler(): - data = request.json - result = server.set_lr_scheduler(**data) - return jsonify(result) - - @app.route('/train_step', methods=['POST']) - def train_step(): - data = request.json - result = server.train_step(batch=data.get('batch', {})) - return jsonify(result) - - @app.route('/save', methods=['POST']) - def save(): - data = request.json - result = server.save_checkpoint(output_path=data.get('output_path')) - return jsonify(result) - - @app.route('/configs', methods=['GET']) - def configs(): - result = server.get_train_configs() - return jsonify(result) - - logger.info(f'Starting server on {args.host}:{args.port}') - app.run(host=args.host, port=args.port, threaded=False) - - -if 
__name__ == '__main__': - main() - diff --git a/src/twinkle/megatron/distributed/__init__.py b/src/twinkle/megatron/distributed/__init__.py new file mode 100644 index 00000000..9c4a411a --- /dev/null +++ b/src/twinkle/megatron/distributed/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Distributed training utilities for Megatron-based models.""" + +from .lora_ddp import ( + LoRADistributedDataParallel, + wrap_model_with_lora_ddp, +) + +__all__ = [ + 'LoRADistributedDataParallel', + 'wrap_model_with_lora_ddp', +] diff --git a/src/twinkle/megatron/distributed/lora_ddp.py b/src/twinkle/megatron/distributed/lora_ddp.py new file mode 100644 index 00000000..ae5a9136 --- /dev/null +++ b/src/twinkle/megatron/distributed/lora_ddp.py @@ -0,0 +1,531 @@ +# Copyright (c) twinkle authors. All rights reserved. +""" +LoRA-aware Distributed Data Parallel wrapper for Megatron models. + +This module provides a DDP wrapper that: +1. Only creates gradient buffers for LoRA parameters (trainable) +2. Supports communication-computation overlap +3. Supports multi-tenant LoRA training with separate process groups +4. 
Inherits from Megatron DDP to reuse optimized communication logic +""" + +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union + +import torch +import torch.distributed as dist +import torch.nn as nn + +logger = logging.getLogger(__name__) + +try: + from megatron.core import parallel_state as mpu + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig + from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets + from megatron.core.distributed.data_parallel_base import _BaseDataParallel + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.process_groups_config import ProcessGroupCollection + MEGATRON_AVAILABLE = True +except ImportError: + MEGATRON_AVAILABLE = False + MegatronDDP = object + _BaseDataParallel = object + + +class LoRADistributedDataParallel(_BaseDataParallel): + """ + Distributed Data Parallel wrapper for LoRA/PEFT models. + + This class inherits from Megatron's _BaseDataParallel and implements + DDP functionality specifically for LoRA parameters. Key features: + + 1. **Selective Parameter Registration**: Only LoRA parameters (trainable) + are registered for gradient synchronization, reducing memory overhead. + + 2. **Communication-Computation Overlap**: When overlap_grad_reduce=True, + gradient all-reduce operations are overlapped with backward computation. + + 3. **Gradient Bucketing**: Parameters are grouped into buckets for efficient + communication, reducing kernel launch overhead. + + 4. **Multi-Tenant Support**: Each tenant can have its own process group + for gradient synchronization. + + 5. **Dynamic Parameter Updates**: Supports adding/removing LoRA parameters + at runtime (requires buffer rebuild). + + Args: + config: Transformer configuration. 
+ ddp_config: DDP configuration controlling overlap, bucketing, etc. + module: The model containing LoRA layers. + disable_bucketing: If True, all parameters go into a single bucket. + lora_param_patterns: Set of patterns to identify LoRA parameters. + tenant_id: Identifier for multi-tenant scenarios. + tenant_process_group: Custom process group for this tenant. + + Example: + >>> # Create DDP wrapper for LoRA model + >>> ddp_config = DistributedDataParallelConfig( + ... overlap_grad_reduce=True, + ... bucket_size=10000000, + ... ) + >>> ddp_model = LoRADistributedDataParallel( + ... config=transformer_config, + ... ddp_config=ddp_config, + ... module=lora_model, + ... ) + >>> + >>> # Training loop + >>> for batch in dataloader: + ... output = ddp_model(batch) + ... loss = compute_loss(output) + ... loss.backward() + ... ddp_model.finish_grad_sync() # Wait for async grad sync + ... optimizer.step() + ... ddp_model.zero_grad_buffer() + """ + + # Default patterns to identify LoRA parameters + DEFAULT_LORA_PATTERNS = {'lora_A', 'lora_B', 'lora_'} + + def __init__( + self, + config: 'TransformerConfig', + ddp_config: 'DistributedDataParallelConfig', + module: nn.Module, + disable_bucketing: bool = False, + lora_param_patterns: Optional[Set[str]] = None, + tenant_id: str = 'default', + tenant_process_group: Optional[dist.ProcessGroup] = None, + ): + if not MEGATRON_AVAILABLE: + raise ImportError("Megatron-Core is required for LoRADistributedDataParallel") + + super().__init__(config=config, module=module) + + self.ddp_config = ddp_config + self.tenant_id = tenant_id + self.lora_param_patterns = lora_param_patterns or self.DEFAULT_LORA_PATTERNS + self._disable_bucketing = disable_bucketing + + # Setup process groups + self._setup_process_groups(tenant_process_group) + + # Configure bucket size + if ddp_config.bucket_size is None: + # Use smaller default for LoRA (fewer parameters) + ddp_config.bucket_size = max( + 10000000, 500000 * self.dp_group.size() + ) + if not 
ddp_config.overlap_grad_reduce: + ddp_config.bucket_size = None + + self.bucket_size = ddp_config.bucket_size + if disable_bucketing: + self.bucket_size = None + + # Initialize data structures + self.param_to_bucket_group = {} + self.params_with_grad = [] + self.buffers: List['_ParamAndGradBuffer'] = [] + self.bucket_groups = [] + self.expert_parallel_buffers = [] + self.expert_parallel_bucket_groups = [] + self.grad_accs = [] + + # Register LoRA parameters and hooks + self._register_lora_params() + self._register_backward_hooks() + + # Forward hooks for param gather overlap (usually not needed for LoRA) + self.use_forward_hook = ( + ddp_config.use_distributed_optimizer and ddp_config.overlap_param_gather + ) + self.remove_forward_pre_hook_handles = {} + self.overlap_param_gather_with_optimizer_step = False + + logger.info( + f"LoRADistributedDataParallel initialized for tenant '{tenant_id}' " + f"with {len(self.params_with_grad)} LoRA parameters, " + f"{len(self.bucket_groups)} bucket groups" + ) + + def _setup_process_groups(self, tenant_process_group: Optional[dist.ProcessGroup]): + """ + Setup process groups for gradient communication. + + If tenant_process_group is provided, use it for DP communication. + Otherwise, use the default Megatron parallel state groups. 
+ """ + if tenant_process_group is not None: + # Use custom tenant process group + self.dp_group = tenant_process_group + self.dp_cp_group = tenant_process_group + self.intra_dp_cp_group = tenant_process_group + # Expert groups use defaults (MoE multi-tenant not supported yet) + try: + self.expt_dp_group = mpu.get_expert_data_parallel_group() + self.intra_expt_dp_group = mpu.get_expert_data_parallel_group( + partial_expert_data_parallel=True + ) + except: + self.expt_dp_group = None + self.intra_expt_dp_group = None + else: + # Use default Megatron process groups + self.dp_group = mpu.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) + self.dp_cp_group = mpu.get_data_parallel_group( + with_context_parallel=True, partial_data_parallel=False + ) + self.intra_dp_cp_group = mpu.get_data_parallel_group( + with_context_parallel=True, partial_data_parallel=True + ) + try: + self.expt_dp_group = mpu.get_expert_data_parallel_group() + self.intra_expt_dp_group = mpu.get_expert_data_parallel_group( + partial_expert_data_parallel=True + ) + except: + self.expt_dp_group = None + self.intra_expt_dp_group = None + + self.tp_group = mpu.get_tensor_model_parallel_group() + self.pp_group = mpu.get_pipeline_model_parallel_group() + try: + self.ep_group = mpu.get_expert_model_parallel_group() + except: + self.ep_group = None + + def _is_lora_param(self, name: str) -> bool: + """Check if a parameter is a LoRA parameter based on name patterns.""" + for pattern in self.lora_param_patterns: + if pattern in name: + return True + return False + + def _register_lora_params(self): + """ + Register LoRA parameters to gradient buffers. + + This method: + 1. Identifies LoRA parameters by name patterns + 2. Groups them by dtype + 3. Creates gradient buffers for efficient communication + 4. 
Sets up bucket groups for overlapped communication + """ + param_to_name = {} + lora_params = [] + + for name, param in self.module.named_parameters(): + if not param.requires_grad: + continue + + # Only process LoRA parameters + if not self._is_lora_param(name): + continue + + self.params_with_grad.append(param) + param.grad_added_to_main_grad = False + param_to_name[param] = name + lora_params.append(param) + + if not lora_params: + logger.warning( + f"No LoRA parameters found for tenant '{self.tenant_id}'. " + f"Patterns used: {self.lora_param_patterns}" + ) + return + + # Calculate gradient scaling factor + if self.config.calculate_per_token_loss: + gradient_scaling_factor = 1.0 + else: + if self.ddp_config.average_in_collective: + gradient_scaling_factor = 1.0 + else: + gradient_scaling_factor = 1.0 / self.dp_cp_group.size() + + # Group parameters by dtype + param_and_grad_dtype_to_params = {} + param_and_grad_dtype_to_indices = {} + + for param in lora_params: + param_dtype = param.dtype + grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype + + key = (param_dtype, grad_dtype) + if key not in param_and_grad_dtype_to_params: + param_and_grad_dtype_to_params[key] = [] + param_and_grad_dtype_to_indices[key] = [] + param_and_grad_dtype_to_params[key].append(param) + param_and_grad_dtype_to_indices[key].append(len(param_and_grad_dtype_to_params[key]) - 1) + + # Create gradient buffers for each dtype combination + pg_collection = ProcessGroupCollection() + pg_collection.tp = self.tp_group + pg_collection.dp_cp = self.dp_cp_group + + for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items(): + indices = param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] + + buffer = _ParamAndGradBuffer( + self.ddp_config, + param_dtype, + grad_dtype, + params, + self.intra_dp_cp_group, + self.bucket_size, + param_to_name, + gradient_scaling_factor, + indices, + getattr(self.ddp_config, 'nccl_ub', False), + pg_collection, + ) 
+ self.buffers.append(buffer) + + # Create bucket groups + self.bucket_groups = partition_buckets( + self.buffers, + force_single_bucket_group=self._disable_bucketing + ) + + # Build param to bucket group mapping + for bucket_group in self.bucket_groups: + for bucket in bucket_group.buckets: + for param in bucket.params: + self.param_to_bucket_group[param] = bucket_group + + def _register_backward_hooks(self): + """ + Register backward hooks for LoRA parameters. + + These hooks: + 1. Accumulate gradients to main_grad buffer + 2. Trigger async gradient communication when a bucket is ready + """ + for param in self.params_with_grad: + if param not in self.param_to_bucket_group: + continue + + # Get gradient accumulator and register hook + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_backward_post_hook(param)) + self.grad_accs.append(grad_acc) + + def _make_backward_post_hook(self, param: nn.Parameter): + """ + Create a backward post-hook for a parameter. + + When the parameter's gradient is computed: + 1. Accumulate it to the main_grad buffer + 2. If overlap is enabled AND this is the last microbatch, start async communication + + Note: register_grad_ready() internally checks is_last_microbatch, so we don't + need to check it here. The bucket_group will only start communication when + all params are ready AND it's the last microbatch. 
+ """ + def hook(*unused): + if param in self.param_to_bucket_group: + # Accumulate gradient to main_grad + if param.grad is not None and not param.grad_added_to_main_grad: + param.main_grad.add_(param.grad.data) + param.grad = None + + # If overlap enabled, notify bucket that param is ready + # Note: register_grad_ready internally checks is_last_microbatch + # and only registers when processing the last microbatch + if self.ddp_config.overlap_grad_reduce: + bucket_group = self.param_to_bucket_group[param] + # Only register if this is the last microbatch + # (bucket_group.is_last_microbatch controls this) + if bucket_group.is_last_microbatch: + bucket_group.register_grad_ready(param) + + return hook + + @contextmanager + def no_sync(self): + """ + Context manager that turns off gradient synchronization. + + Use this for gradient accumulation - only sync on the last microbatch. + """ + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.is_last_microbatch = False + try: + yield + finally: + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.is_last_microbatch = True + + def start_grad_sync(self, *unused): + """ + Start gradient synchronization (all-reduce or reduce-scatter). + + When overlap_grad_reduce=True, this dispatches async operations. + When overlap_grad_reduce=False, this is a no-op (finish_grad_sync does sync). + """ + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.start_grad_sync() + + def finish_grad_sync(self): + """ + Finish gradient synchronization. + + When overlap_grad_reduce=True, waits for async operations to complete. + When overlap_grad_reduce=False, performs synchronous communication. 
+ """ + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.finish_grad_sync() + + def scale_gradients(self, scaling_factor: float): + """Scale all gradients in buffers by the given factor.""" + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.scale_gradients(scaling_factor) + + def zero_grad_buffer(self): + """ + Zero out all gradient buffers. + + Must be called at the beginning of each training iteration. + """ + for param in self.params_with_grad: + param.grad_added_to_main_grad = False + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.reset() + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.reset() + + def broadcast_params(self): + """Broadcast parameters from rank 0 to all DP ranks.""" + for param in self.params_with_grad: + dist.broadcast( + param.data, + src=dist.get_global_rank(self.dp_cp_group, 0), + group=self.dp_cp_group, + ) + + def add_lora_params(self, new_params: Dict[str, nn.Parameter]): + """ + Dynamically add LoRA parameters. + + Note: This requires rebuilding gradient buffers, which is expensive. + Use sparingly. + + Args: + new_params: Dictionary mapping parameter names to parameters. + """ + for name, param in new_params.items(): + if param.requires_grad: + self.params_with_grad.append(param) + param.grad_added_to_main_grad = False + + self._rebuild_buffers() + + def remove_lora_params(self, param_names: Set[str]): + """ + Remove LoRA parameters. + + Note: This requires rebuilding gradient buffers, which is expensive. + + Args: + param_names: Set of parameter names to remove. 
+ """ + new_params_with_grad = [] + for param in self.params_with_grad: + # Find param name in module + for name, p in self.module.named_parameters(): + if p is param and name not in param_names: + new_params_with_grad.append(param) + break + + self.params_with_grad = new_params_with_grad + self._rebuild_buffers() + + def _rebuild_buffers(self): + """Rebuild gradient buffers after parameter changes.""" + # Clear old hooks and buffers + self.grad_accs.clear() + self.param_to_bucket_group.clear() + self.buffers.clear() + self.bucket_groups.clear() + + # Re-register + self._register_lora_params() + self._register_backward_hooks() + + logger.info( + f"Rebuilt buffers for tenant '{self.tenant_id}': " + f"{len(self.params_with_grad)} params, {len(self.bucket_groups)} bucket groups" + ) + + def get_lora_param_count(self) -> int: + """Get the number of registered LoRA parameters.""" + return len(self.params_with_grad) + + def get_lora_param_numel(self) -> int: + """Get the total number of elements in LoRA parameters.""" + return sum(p.numel() for p in self.params_with_grad) + + +def wrap_model_with_lora_ddp( + model: nn.Module, + config: 'TransformerConfig', + ddp_config: Optional['DistributedDataParallelConfig'] = None, + lora_param_patterns: Optional[Set[str]] = None, + tenant_id: str = 'default', + tenant_process_group: Optional[dist.ProcessGroup] = None, + overlap_grad_reduce: bool = True, + bucket_size: Optional[int] = None, +) -> LoRADistributedDataParallel: + """ + Convenience function to wrap a LoRA model with DDP. + + This is the recommended way to create a LoRADistributedDataParallel wrapper. + + Args: + model: Model containing LoRA layers. + config: Transformer configuration. + ddp_config: DDP configuration. If None, creates default config. + lora_param_patterns: Patterns to identify LoRA parameters. + tenant_id: Tenant identifier for multi-tenant scenarios. + tenant_process_group: Custom process group for this tenant. 
+ overlap_grad_reduce: Enable communication-computation overlap. + bucket_size: Size of gradient buckets. None for auto. + + Returns: + LoRADistributedDataParallel wrapper. + + Example: + >>> ddp_model = wrap_model_with_lora_ddp( + ... model=lora_model, + ... config=transformer_config, + ... overlap_grad_reduce=True, + ... ) + """ + if not MEGATRON_AVAILABLE: + raise ImportError("Megatron-Core is required for wrap_model_with_lora_ddp") + + if ddp_config is None: + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=overlap_grad_reduce, + use_distributed_optimizer=False, # LoRA params are small + bucket_size=bucket_size, + ) + + if lora_param_patterns is None: + lora_param_patterns = LoRADistributedDataParallel.DEFAULT_LORA_PATTERNS + + return LoRADistributedDataParallel( + config=config, + ddp_config=ddp_config, + module=model, + lora_param_patterns=lora_param_patterns, + tenant_id=tenant_id, + tenant_process_group=tenant_process_group, + ) diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index df71d44c..1d21a0bb 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -273,6 +273,7 @@ class BridgeConfig: vocab_size: int = 32000 padded_vocab_size: int = 32000 intermediate_size: int = 11008 + kv_channels: int = None # head_dim, if None will be computed from hidden_size // num_attention_heads # Options add_qkv_bias: bool = False @@ -328,6 +329,17 @@ def from_hf_config( else: add_qkv_bias = False + # Determine QK layernorm setting + # Qwen3 uses QK layernorm but doesn't have explicit config attribute + qk_layernorm = getattr(hf_config, 'qk_layernorm', False) or \ + getattr(hf_config, 'use_qk_norm', False) + if not qk_layernorm and model_type in ('qwen3', 'qwen3_moe'): + # Qwen3 (dense and MoE) always uses QK layernorm (q_norm, k_norm weights) + qk_layernorm = True + + # Determine kv_channels (head_dim) - Qwen3 has explicit head_dim + kv_channels = getattr(hf_config, 
'head_dim', None) + return cls( tp_size=tp_size, pp_size=pp_size, @@ -342,13 +354,13 @@ def from_hf_config( intermediate_size=getattr(hf_config, 'intermediate_size', 11008), add_qkv_bias=add_qkv_bias, add_bias_linear=getattr(hf_config, 'mlp_bias', False), - qk_layernorm=getattr(hf_config, 'qk_layernorm', False) or \ - getattr(hf_config, 'use_qk_norm', False), + qk_layernorm=qk_layernorm, tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False), num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, shared_expert_intermediate_size=shared_expert_size, model_type=model_type, + kv_channels=kv_channels, # Explicit head_dim for Qwen3 ) @@ -599,7 +611,8 @@ def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): num_heads = self.config.num_attention_heads num_kv_heads = self.config.num_key_value_heads hidden_size = self.config.hidden_size - head_dim = hidden_size // num_heads + # Use kv_channels (head_dim) from config if available (for Qwen3 etc.) + head_dim = getattr(self.config, 'kv_channels', hidden_size // num_heads) heads_per_group = num_heads // num_kv_heads # Load Q, K, V weights and merge into linear_qkv @@ -607,6 +620,11 @@ def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): k_weight = loader.get_tensor(f'{prefix}k_proj.weight') v_weight = loader.get_tensor(f'{prefix}v_proj.weight') + # Infer head_dim from actual weight shapes if needed + actual_kv_dim = k_weight.shape[0] // num_kv_heads + if actual_kv_dim != head_dim: + head_dim = actual_kv_dim + # Reshape for GQA q_weight = q_weight.reshape(num_kv_heads, heads_per_group * head_dim, hidden_size) k_weight = k_weight.reshape(num_kv_heads, head_dim, hidden_size) @@ -628,9 +646,12 @@ def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): k_bias = loader.get_tensor(f'{prefix}k_proj.bias') v_bias = loader.get_tensor(f'{prefix}v_proj.bias') - q_bias = q_bias.reshape(num_kv_heads, heads_per_group * head_dim) - k_bias = 
k_bias.reshape(num_kv_heads, head_dim) - v_bias = v_bias.reshape(num_kv_heads, head_dim) + # Infer head_dim from actual bias shapes if needed + actual_bias_head_dim = k_bias.shape[0] // num_kv_heads + + q_bias = q_bias.reshape(num_kv_heads, heads_per_group * actual_bias_head_dim) + k_bias = k_bias.reshape(num_kv_heads, actual_bias_head_dim) + v_bias = v_bias.reshape(num_kv_heads, actual_bias_head_dim) qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).reshape(-1) self._set_weight(mg_attn.linear_qkv.bias, qkv_bias, 'linear_qkv.bias') @@ -708,11 +729,19 @@ def _load_mlp(self, mg_layer, loader: SafetensorLoader, layer_idx: int): pass def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): - """Load MoE layer weights.""" + """Load MoE layer weights. + + Handles Expert Parallel (EP) sharding - each EP rank loads only its + assigned subset of experts based on ep_rank and ep_size. + + For EP=2 with 128 experts: + - EP rank 0 loads experts 0-63 + - EP rank 1 loads experts 64-127 + """ mg_mlp = mg_layer.mlp prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' 
- # Load router + # Load router (replicated across all ranks) try: router_key = None for key in ['gate.weight', 'router.weight', 'gate.wg.weight']: @@ -726,6 +755,18 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): router_module = deep_getattr(mg_mlp, 'router') if router_module is not None and hasattr(router_module, 'weight'): router_module.weight.data.copy_(router_weight) + + # Load expert bias if present (for sigmoid routers like Qwen3) + for bias_key in ['gate.e_score_correction_bias', 'moe_statics.e_score_correction_bias']: + full_bias_key = f'{prefix}{bias_key}' + if full_bias_key in loader: + try: + expert_bias = loader.get_tensor(full_bias_key) + if router_module is not None and hasattr(router_module, 'expert_bias'): + router_module.expert_bias.data.copy_(expert_bias) + break + except KeyError: + continue except KeyError: pass @@ -745,25 +786,113 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): break except KeyError: continue - - # Load experts + + # Load shared expert gate if present + for gate_key in ['shared_expert_gate.weight']: + full_gate_key = f'{prefix}{gate_key}' + if full_gate_key in loader: + try: + gate_weight = loader.get_tensor(full_gate_key) + shared_module = deep_getattr(mg_mlp, 'shared_experts') + if shared_module is not None and hasattr(shared_module, 'gate_weight'): + shared_module.gate_weight.data.copy_(gate_weight) + break + except KeyError: + continue + + # Load experts with EP sharding num_local_experts = self.config.num_experts // self.ep_size + start_expert_idx = self.ep_rank * num_local_experts experts_module = deep_getattr(mg_mlp, 'experts') if experts_module is not None: - for local_idx in range(num_local_experts): - global_idx = self.ep_rank * num_local_experts + local_idx + # Determine expert module type + if hasattr(experts_module, 'weight1'): + # GroupedMLP format - weights are merged: [hidden, num_experts * ffn_hidden] + # Need to collect all experts and set at once + 
fc1_weights = [] # gate and up weights interleaved + fc2_weights = [] # down weights - try: - gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') - up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') - down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') + for local_idx in range(num_local_experts): + global_idx = start_expert_idx + local_idx + try: + gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') + + # Stack gate and up for gated linear unit + fc1_weights.append(gate_weight) # [ffn_hidden, hidden] + fc1_weights.append(up_weight) # [ffn_hidden, hidden] + fc2_weights.append(down_weight) # [hidden, ffn_hidden] + except KeyError as e: + print(f"Warning: Missing expert {global_idx} weights: {e}") + continue + + if fc1_weights and fc2_weights: + # GroupedMLP weight1: [hidden, num_experts * 2 * ffn_hidden] (transposed) + # HF format: [num_experts * 2, ffn_hidden, hidden] + fc1_stacked = torch.cat(fc1_weights, dim=0) # [num_experts*2*ffn_hidden, hidden] + fc1_stacked = fc1_stacked.t().contiguous() # [hidden, num_experts*2*ffn_hidden] + + # GroupedMLP weight2: [num_experts * ffn_hidden, hidden] + fc2_stacked = torch.cat(fc2_weights, dim=0) # [num_experts*hidden, ffn_hidden] + + # Set weights directly + if experts_module.weight1.shape == fc1_stacked.shape: + experts_module.weight1.data.copy_(fc1_stacked) + else: + # Handle TP split + tp_rank = self.tp_rank + tp_size = self.tp_size + if tp_size > 1: + # Split along last dim for weight1 + chunk_size = fc1_stacked.shape[1] // tp_size + fc1_chunk = fc1_stacked[:, tp_rank * chunk_size:(tp_rank + 1) * chunk_size] + experts_module.weight1.data.copy_(fc1_chunk) + else: + experts_module.weight1.data.copy_(fc1_stacked) - # For grouped linear, weights are 
stored differently - if hasattr(experts_module, 'linear_fc1'): - # TEGroupedLinear format + if experts_module.weight2.shape == fc2_stacked.shape: + experts_module.weight2.data.copy_(fc2_stacked) + else: + # Handle TP split + tp_rank = self.tp_rank + tp_size = self.tp_size + if tp_size > 1: + # Split along first dim for weight2 + chunk_size = fc2_stacked.shape[0] // tp_size + fc2_chunk = fc2_stacked[tp_rank * chunk_size:(tp_rank + 1) * chunk_size, :] + experts_module.weight2.data.copy_(fc2_chunk) + else: + experts_module.weight2.data.copy_(fc2_stacked) + + elif hasattr(experts_module, 'local_experts'): + # SequentialMLP format with local_experts list + for local_idx in range(num_local_experts): + global_idx = start_expert_idx + local_idx + try: + gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') + + expert = experts_module.local_experts[local_idx] + if hasattr(expert, 'linear_fc1'): + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) + self._set_weight(expert.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') + self._set_weight(expert.linear_fc2.weight, down_weight, 'linear_fc2.weight') + except KeyError: + continue + + elif hasattr(experts_module, 'linear_fc1'): + # TEGroupedLinear format - weights stored as weight0, weight1, etc. 
+ for local_idx in range(num_local_experts): + global_idx = start_expert_idx + local_idx + try: + gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) - # Set individual expert weight fc1_param = getattr(experts_module.linear_fc1, f'weight{local_idx}', None) if fc1_param is not None: self._set_weight(fc1_param, fc1_weight, 'linear_fc1.weight', is_expert=True) @@ -771,23 +900,22 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): fc2_param = getattr(experts_module.linear_fc2, f'weight{local_idx}', None) if fc2_param is not None: self._set_weight(fc2_param, down_weight, 'linear_fc2.weight', is_expert=True) - elif hasattr(experts_module, '__getitem__'): - # List of experts - expert = experts_module[local_idx] - if hasattr(expert, 'linear_fc1'): - fc1_weight = torch.stack([gate_weight, up_weight], dim=0) - self._set_weight(expert.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') - self._set_weight(expert.linear_fc2.weight, down_weight, 'linear_fc2.weight') - except KeyError: - continue + except KeyError: + continue - # Load post attention layernorm + # Load post attention layernorm (pre_mlp_layernorm for MoE) ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' try: ln_weight = loader.get_tensor(ln_key) - ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') - if ln_param is not None: - ln_param.data.copy_(ln_weight) + # Try pre_mlp_layernorm first (used in MoE layers) + ln_module = deep_getattr(mg_layer, 'pre_mlp_layernorm') + if ln_module is not None and hasattr(ln_module, 'weight'): + ln_module.weight.data.copy_(ln_weight) + else: + # Fallback to linear_fc1.layer_norm_weight + ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln_param is 
not None: + ln_param.data.copy_(ln_weight) except KeyError: pass @@ -1312,6 +1440,7 @@ def __init__( params_dtype=None, use_cpu_initialization: bool = False, attention_backend: str = 'flash', + sequence_parallel: bool = False, recompute_granularity: Optional[str] = 'selective', recompute_modules: Optional[list] = None, recompute_method: Optional[str] = None, @@ -1328,6 +1457,7 @@ def __init__( params_dtype: Parameter dtype (default: torch.bfloat16). use_cpu_initialization: Initialize on CPU first. attention_backend: Attention backend. + sequence_parallel: Enable sequence parallelism. Required for MoE with TP > 1. recompute_granularity: Activation recomputation strategy. 'selective' (default): Only recompute core attention (memory efficient). 'full': Recompute entire transformer layer (most memory efficient). @@ -1347,6 +1477,7 @@ def __init__( self.params_dtype = params_dtype if params_dtype is not None else torch.bfloat16 self.use_cpu_initialization = use_cpu_initialization self.attention_backend = attention_backend + self.sequence_parallel = sequence_parallel self.recompute_granularity = recompute_granularity self.recompute_modules = recompute_modules or ['core_attn'] self.recompute_method = recompute_method @@ -1452,6 +1583,7 @@ def _create_model_from_config( import torch.distributed as dist from megatron.core import parallel_state as mpu from megatron.core.transformer import TransformerConfig + from megatron.core.transformer.enums import AttnBackend from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_with_transformer_engine_spec, @@ -1505,27 +1637,96 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): if hasattr(model_chunk, 'finish_grad_sync'): model_chunk.finish_grad_sync() + # MoE configuration + num_experts = mg_config_dict.get('num_experts', 0) or 0 + moe_ffn_hidden_size = mg_config_dict.get('moe_ffn_hidden_size') + moe_router_topk = 
mg_config_dict.get('moe_router_topk', 2) or 2 + moe_shared_expert_intermediate_size = mg_config_dict.get('moe_shared_expert_intermediate_size') + + # Build MoE-related kwargs + moe_kwargs = {} + if num_experts > 0: + moe_kwargs.update({ + 'num_moe_experts': num_experts, + 'moe_router_topk': moe_router_topk, + 'moe_router_load_balancing_type': mg_config_dict.get('moe_router_load_balancing_type', 'aux_loss'), + # MoE performance optimizations (aligned with Swift defaults) + 'moe_token_dispatcher_type': mg_config_dict.get('moe_token_dispatcher_type', 'alltoall'), # 'alltoall' is more efficient than 'allgather' + 'moe_grouped_gemm': mg_config_dict.get('moe_grouped_gemm', True), # Enable for better performance (requires grouped_gemm package) + 'moe_aux_loss_coeff': mg_config_dict.get('moe_aux_loss_coeff', 0.0), # Auxiliary load balancing loss coefficient + }) + + # FFN hidden size for MoE + if moe_ffn_hidden_size: + moe_kwargs['moe_ffn_hidden_size'] = moe_ffn_hidden_size + + # Shared expert configuration + if moe_shared_expert_intermediate_size: + moe_kwargs['moe_shared_expert_intermediate_size'] = moe_shared_expert_intermediate_size + + # Router score function (sigmoid for Qwen3, softmax for others) + if mg_config_dict.get('moe_router_score_function'): + moe_kwargs['moe_router_score_function'] = mg_config_dict['moe_router_score_function'] + + # Expert bias for sigmoid router + if mg_config_dict.get('moe_router_enable_expert_bias'): + moe_kwargs['moe_router_enable_expert_bias'] = mg_config_dict['moe_router_enable_expert_bias'] + + # Sequence parallel requires TP > 1 + # Auto-enable for MoE with TP > 1 (required by Megatron) + use_sequence_parallel = self.sequence_parallel and self.tp_size > 1 + if num_experts > 0 and self.tp_size > 1 and not use_sequence_parallel: + use_sequence_parallel = True + print(f"Auto-enabling sequence_parallel for MoE with TP={self.tp_size}") + + # For MoE models, ffn_hidden_size should be moe_ffn_hidden_size if not specified + ffn_hidden_size 
= mg_config_dict.get('ffn_hidden_size') + if ffn_hidden_size is None: + ffn_hidden_size = moe_ffn_hidden_size or (4 * mg_config_dict['hidden_size']) + + # For models with non-standard head dimensions (like Qwen3-30B-A3B) + kv_channels = mg_config_dict.get('kv_channels') + + # Activation function for SwiGLU (required by Megatron when gated_linear_unit=True) + use_swiglu = mg_config_dict.get('swiglu', True) + activation_func = torch.nn.functional.silu if use_swiglu else torch.nn.functional.gelu + + # Enable bias_activation_fusion for SwiGLU (same as Swift) + # Note: Only works with TransformerEngine and no bias in linear layers + has_bias = not mg_config_dict.get('disable_bias_linear', True) + bias_activation_fusion = use_swiglu and not has_bias + config = TransformerConfig( num_layers=num_layers, hidden_size=mg_config_dict['hidden_size'], num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, - ffn_hidden_size=mg_config_dict.get('ffn_hidden_size', 4 * mg_config_dict['hidden_size']), + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, tensor_model_parallel_size=self.tp_size, pipeline_model_parallel_size=self.pp_size, context_parallel_size=self.cp_size, expert_model_parallel_size=self.ep_size, + sequence_parallel=use_sequence_parallel, params_dtype=self.params_dtype, pipeline_dtype=self.params_dtype, # Required when using pipeline parallelism use_cpu_initialization=self.use_cpu_initialization, add_qkv_bias=mg_config_dict.get('add_qkv_bias', False), add_bias_linear=not mg_config_dict.get('disable_bias_linear', True), - gated_linear_unit=mg_config_dict.get('swiglu', True), + gated_linear_unit=use_swiglu, + activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise + bias_activation_fusion=bias_activation_fusion, # Fused SwiGLU for performance normalization='RMSNorm', layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6), qk_layernorm=mg_config_dict.get('qk_layernorm', False), hidden_dropout=0.0, attention_dropout=0.0, + # 
Performance optimizations + masked_softmax_fusion=True, # Fused attention softmax + bias_dropout_fusion=True, # Fused bias + dropout + apply_rope_fusion=True, # Fused RoPE application + attention_softmax_in_fp32=True, # Numerical stability + attention_backend=AttnBackend.flash, # FlashAttention for speed # Activation recomputation for memory efficiency recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, @@ -1534,20 +1735,26 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): # Critical: Set finalize_model_grads_func for DP gradient synchronization # Uses custom wrapper that handles both DDP and PEFT/LoRA models finalize_model_grads_func=finalize_model_grads_for_lora, + # MoE configuration + **moe_kwargs, ) - # Get layer spec + # Save transformer config for later use (e.g., DDP wrapping) + self._transformer_config = config + + # Get layer spec - enable moe_grouped_gemm for MoE models + moe_grouped_gemm = num_experts > 0 try: layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=mg_config_dict.get('num_experts'), - moe_grouped_gemm=False, + moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=mg_config_dict.get('qk_layernorm', False), ) except Exception: from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec layer_spec = get_gpt_layer_local_spec( num_experts=mg_config_dict.get('num_experts'), - moe_grouped_gemm=False, + moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=mg_config_dict.get('qk_layernorm', False), ) diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 281256b4..1cb70474 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -54,6 +54,10 @@ class MegatronOptimizerGroup: gradient_accumulation_steps: int = 1 cur_step: int = 0 dp_group = None + # Megatron optimizer specific fields + is_megatron_optimizer: bool = False + _last_grad_norm: float = 0.0 + 
_last_step_success: bool = True def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: """Check if gradient synchronization should happen.""" @@ -227,10 +231,11 @@ def _create_megatron_model_with_bridge( params_dtype=params_dtype, use_cpu_initialization=False, attention_backend='flash', # Use flash for training performance + sequence_parallel=self.strategy.sequence_parallel, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules, - recompute_method=getattr(self, 'recompute_method', None), - recompute_num_layers=getattr(self, 'recompute_num_layers', None), + recompute_method=getattr(self, "recompute_method", None), + recompute_num_layers=getattr(self, "recompute_num_layers", None), ) # Create model (this calls initialize_megatron internally) @@ -240,6 +245,9 @@ def _create_megatron_model_with_bridge( self.strategy._initialized = True self.strategy._parallel_state = mpu + # Save transformer config for DDP wrapping + self._transformer_config = getattr(self._bridge_initializer, '_transformer_config', None) + # Move to GPU model = self._move_model_to_gpu(model) @@ -512,7 +520,8 @@ def backward(self, **kwargs): optimizer_config.cur_step += 1 @remote_function(dispatch='all', collect='avg', sync=True) - def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + num_microbatches: int = 1, **kwargs): """Combined forward and backward pass using Megatron's scheduler. 
Note: sync=True is required for Ray mode because Megatron's pipeline @@ -522,14 +531,21 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr Always uses Megatron's get_forward_backward_func() which handles: - Pipeline scheduling (1F1B, interleaved, or no-pipeline) - Communication between stages (using proper process groups for multi-tenant isolation) - - Gradient accumulation + - Gradient accumulation across microbatches Args: - inputs: Model inputs. + inputs: Model inputs. Can be: + - A single batch dict (num_microbatches=1) + - A list of batch dicts (num_microbatches=len(inputs)) + - An iterator yielding batch dicts + num_microbatches: Number of microbatches to process in one call. + If inputs is a list, this is inferred from len(inputs). + Using num_microbatches > 1 enables Megatron's native gradient + accumulation with better memory management and compute overlap. **kwargs: Additional arguments. Returns: - Loss value. + Average loss value across all microbatches. """ from functools import partial from megatron.core.pipeline_parallel import get_forward_backward_func @@ -538,36 +554,47 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() - # Encode inputs if needed - if isinstance(inputs, dict) and 'input_ids' not in inputs: - if optimizer_config.template is not None: - inputs = optimizer_config.template.encode(inputs) - if isinstance(inputs, list) and 'input_ids' not in inputs[0]: - if optimizer_config.template is not None: - inputs = optimizer_config.template.batch_encode(inputs) - - # Process inputs - processor = optimizer_config.processor - if processor is not None: - inputs = processor(inputs) + # Handle different input formats + # 1. Single batch dict -> wrap in list + # 2. List of batches -> use as-is + # 3. 
Iterator -> convert to list + if isinstance(inputs, dict): + microbatch_list = [inputs] + elif hasattr(inputs, '__iter__') and not isinstance(inputs, (list, tuple)): + # Iterator - convert to list + microbatch_list = list(inputs) + else: + microbatch_list = list(inputs) + + # Infer num_microbatches from inputs if list is provided + if len(microbatch_list) > 1: + num_microbatches = len(microbatch_list) + + # Process each microbatch + processed_batches = [] + for batch in microbatch_list: + # Encode inputs if needed + if isinstance(batch, dict) and 'input_ids' not in batch: + if optimizer_config.template is not None: + batch = optimizer_config.template.encode(batch) + + # Process inputs + processor = optimizer_config.processor + if processor is not None: + batch = processor(batch) + + processed_batches.append(batch) - # Store labels before removing from inputs - labels = inputs.get('labels', None) - if 'labels' in inputs: - try: - del inputs['labels'] - except (TypeError, KeyError): - pass # Some dict-like types don't support deletion + # Get first batch for shape info (all batches should have same shape) + first_batch = processed_batches[0] # Get CP size for sequence padding and splitting cp_size = self.strategy.cp_size cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0 - # Get sequence length and batch size - # Note: Megatron's schedule internally divides seq_length by cp_size - # So we pass the padded full sequence length here - original_seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1 - micro_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 + # Get sequence length and batch size from first batch + original_seq_length = first_batch['input_ids'].shape[1] if 'input_ids' in first_batch else 1 + micro_batch_size = first_batch['input_ids'].shape[0] if 'input_ids' in first_batch else 1 # For CP > 1, pad seq_length to be divisible by 2*cp_size if cp_size > 1: @@ -579,12 +606,6 @@ def forward_backward(self, *, 
inputs: Union[InputFeature, List[InputFeature], Tr else: seq_length = original_seq_length - # Move labels to GPU if needed - if labels is not None and not isinstance(labels, torch.Tensor): - labels = torch.tensor(labels, device=torch.cuda.current_device()) - elif labels is not None: - labels = labels.to(torch.cuda.current_device()) - def split_tensor_for_cp(tensor, dim=-1): """ Split tensor along sequence dimension for Context Parallel. @@ -620,10 +641,20 @@ def split_tensor_for_cp(tensor, dim=-1): # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) def forward_step_func(data_iterator, model): batch = next(data_iterator) - input_ids = batch.get('input_ids') - position_ids = batch.get('position_ids') - attention_mask = batch.get('attention_mask') - batch_labels = batch.get('labels', labels) # Use batch labels or passed labels + + # Move tensors to CUDA with non_blocking=True for async transfer + # This matches Swift's to_device(data, 'cuda', non_blocking=True) behavior + def to_cuda_non_blocking(tensor): + if tensor is None: + return None + if isinstance(tensor, torch.Tensor) and not tensor.is_cuda: + return tensor.cuda(non_blocking=True) + return tensor + + input_ids = to_cuda_non_blocking(batch.get('input_ids')) + position_ids = to_cuda_non_blocking(batch.get('position_ids')) + attention_mask = to_cuda_non_blocking(batch.get('attention_mask')) + batch_labels = to_cuda_non_blocking(batch.get('labels')) # Labels should be in each batch # Pad sequence for Context Parallel compatibility # Megatron's RoPE requires seq_len % (2 * cp_size) == 0 @@ -671,54 +702,44 @@ def forward_step_func(data_iterator, model): ) # Megatron's compute_language_model_loss returns per-token loss [batch, seq] - # We need to aggregate it with loss_mask + # We need to aggregate it with loss_mask and return 3 values for proper per-token normalization + # Swift uses 3-value return: (loss, num_tokens, loss_dict) for per-token loss mode def 
megatron_loss_func(labels_for_mask, cp_size, output_tensor): # output_tensor is per-token loss [batch, seq] # Create loss mask from labels (ignore -100) - loss_mask = (labels_for_mask != -100).float() + loss_mask = (labels_for_mask != -100) - # Flatten and compute mean - losses = output_tensor.float().view(-1) - loss_mask_flat = loss_mask.view(-1) + # Compute per-token losses + losses = output_tensor.float() - # Compute local sum and count - local_loss_sum = torch.sum(losses * loss_mask_flat) - local_count = loss_mask_flat.sum() + # Compute sum of losses and token count (same as Swift) + # Swift: loss = torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)]) + loss_sum = torch.sum(losses * loss_mask.float()) + local_num_tokens = loss_mask.sum().to(torch.int) - # For CP > 1, aggregate loss across CP ranks - # Note: Megatron's schedules.py will multiply loss by cp_group_size - # for legacy 2-output loss_func. This assumes loss_func returns SUM/cp_size (MEAN). - # So we should return local MEAN (not global MEAN) and let Megatron handle it. 
+ # For CP > 1, aggregate across CP ranks if cp_size > 1: - # All-reduce the count across CP ranks to get total token count - # This is needed for correct averaging - total_count = local_count.clone() + # All-reduce loss sum and token count across CP ranks + loss_tensor = torch.cat([loss_sum.view(1), local_num_tokens.float().view(1)]) torch.distributed.all_reduce( - total_count, + loss_tensor, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group() ) - - # Return local_loss_sum / total_count - # Megatron will multiply by cp_size, so the final result is: - # (local_loss_sum / total_count) * cp_size - # = (local_loss_sum * cp_size) / total_count - # But we want: SUM(local_loss_sum) / total_count - # So we need to do all_reduce on loss_sum too - total_loss_sum = local_loss_sum.clone() - torch.distributed.all_reduce( - total_loss_sum, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_context_parallel_group() - ) - - # Return global mean, but Megatron will multiply by cp_size - # So we divide by cp_size first to counteract that - loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size - else: - loss = local_loss_sum / local_count.clamp(min=1) + loss_sum = loss_tensor[0] + local_num_tokens = loss_tensor[1].to(torch.int) - return loss, {'loss': loss.detach()} + # Return 3 values for per-token loss mode (same as Swift): + # 1. loss (sum, will be divided by num_tokens by Megatron) + # 2. local_num_tokens (for proper averaging) + # 3. 
loss_dict for logging + reporting_loss = torch.cat([loss_sum.detach().view(1), local_num_tokens.float().view(1)]) + + return ( + loss_sum, + local_num_tokens, + {'lm loss': reporting_loss} + ) return output_tensor, partial(megatron_loss_func, batch_labels, cp_size) @@ -728,32 +749,55 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): # - PP = 1: forward_backward_no_pipelining forward_backward_func = get_forward_backward_func() - # Create single-item iterator - data_iter = iter([inputs]) + # Create iterator over all microbatches + # Megatron's scheduler will call next(data_iterator) num_microbatches times + data_iter = iter(processed_batches) # Run forward-backward with Megatron's scheduler # Megatron handles all communication internally using proper process groups + # With num_microbatches > 1, gradients are accumulated across microbatches losses = forward_backward_func( forward_step_func=forward_step_func, data_iterator=data_iter, model=[self.model], - num_microbatches=1, + num_microbatches=num_microbatches, seq_length=seq_length, micro_batch_size=micro_batch_size, forward_only=False, ) # Extract loss from results (only last PP stage returns non-empty) - loss = 0.0 + # With 3-value loss_func return, each loss_dict contains 'lm loss': [loss_sum, num_tokens] + # We aggregate across all microbatches using proper per-token averaging + total_loss_sum = 0.0 + total_num_tokens = 0 if losses: for loss_dict in losses: - if isinstance(loss_dict, dict) and 'loss' in loss_dict: - loss = loss_dict['loss'] - break - elif isinstance(loss_dict, torch.Tensor): - loss = loss_dict - break + if isinstance(loss_dict, dict): + # New format: 'lm loss' contains [loss_sum, num_tokens] + if 'lm loss' in loss_dict: + reporting = loss_dict['lm loss'] + if isinstance(reporting, torch.Tensor) and reporting.numel() == 2: + total_loss_sum += reporting[0].item() + total_num_tokens += int(reporting[1].item()) + elif isinstance(reporting, (list, tuple)) and len(reporting) == 2: + 
total_loss_sum += float(reporting[0]) + total_num_tokens += int(reporting[1]) + # Legacy format: 'loss' contains average loss + elif 'loss' in loss_dict: + loss_val = loss_dict['loss'] + if isinstance(loss_val, torch.Tensor): + total_loss_sum += loss_val.item() + else: + total_loss_sum += float(loss_val) + total_num_tokens += 1 # Fallback: treat as 1 sample + + # Compute average loss (per-token average across all microbatches) + if total_num_tokens > 0: + loss = total_loss_sum / total_num_tokens + else: + loss = total_loss_sum / max(num_microbatches, 1) # For PP > 1, broadcast loss from last PP stage to all ranks # Note: mpu is imported at module level, no need to reimport @@ -777,14 +821,9 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): optimizer_config.cur_step += 1 - # Critical: Synchronize all DP replicas before returning - # This ensures all DP replicas complete the same training step before - # moving to the next batch, preventing P2P communication deadlocks - dp_world_size = mpu.get_data_parallel_world_size() - if dp_world_size > 1: - # Use barrier on DP+CP group to synchronize all replicas - dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) - dist.barrier(group=dp_cp_group) + # Note: finalize_model_grads is called inside forward_backward_func + # which already handles gradient synchronization across DP replicas. + # No additional barrier is needed here - adding one would hurt performance. if isinstance(loss, torch.Tensor): return loss.detach().cpu().float().numpy() @@ -803,6 +842,15 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwarg Total norm of gradients. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + # Check if using Megatron optimizer (handles clip_grad internally) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', False) + if is_megatron_opt: + # Megatron optimizer handles gradient clipping in step() + # Return the grad_norm from last step if available + return getattr(optimizer_config, '_last_grad_norm', 0.0) + parameters = self._get_trainable_parameters(adapter_name).values() return torch.nn.utils.clip_grad_norm_( @@ -844,16 +892,29 @@ def step(self, **kwargs): optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer correctly before stepping' - optimizer.step(**kwargs) + # Check if using Megatron optimizer (has different step() signature) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', False) + if is_megatron_opt: + # Megatron optimizer step() returns (success, grad_norm, num_zeros) + success, grad_norm, num_zeros = optimizer.step() + # Store grad_norm for later retrieval + optimizer_config._last_grad_norm = grad_norm if grad_norm is not None else 0.0 + optimizer_config._last_step_success = success + else: + optimizer.step(**kwargs) def _is_model_ddp_wrapped(self) -> bool: """Check if model is wrapped with DDP. Returns: - True if model is wrapped with DDP (either Megatron DDP or PyTorch DDP). + True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP). """ from torch.nn.parallel import DistributedDataParallel as TorchDDP - return isinstance(self.model, (MegatronDDP, TorchDDP)) + try: + from twinkle.megatron.distributed import LoRADistributedDataParallel + return isinstance(self.model, (MegatronDDP, LoRADistributedDataParallel, TorchDDP)) + except ImportError: + return isinstance(self.model, (MegatronDDP, TorchDDP)) def _get_unwrapped_model(self) -> nn.Module: """Get the unwrapped model. 
@@ -862,6 +923,120 @@ def _get_unwrapped_model(self) -> nn.Module: The base model without DDP wrapper. """ return self.strategy.unwrap_model(self.model) + + @remote_function(dispatch='all') + def wrap_with_lora_ddp( + self, + adapter_name: str = _default_adapter_name, + overlap_grad_reduce: bool = True, + bucket_size: Optional[int] = None, + lora_param_patterns: Optional[set] = None, + **kwargs + ): + """ + Wrap the model with LoRA-aware DDP for efficient distributed training. + + This enables: + 1. Communication-computation overlap: Gradient all-reduce starts while + backward pass is still computing other gradients. + 2. Gradient bucketing: Small gradients are grouped for efficient communication. + 3. Async gradient reduction: Non-blocking communication operations. + + Should be called AFTER add_adapter_to_model() and BEFORE training starts. + + Args: + adapter_name: Name of the adapter (for multi-adapter scenarios). + overlap_grad_reduce: Enable communication-computation overlap. + Set to True for best performance (default). + bucket_size: Size of gradient buckets in number of elements. + None for automatic sizing based on LoRA parameter count. + lora_param_patterns: Set of patterns to identify LoRA parameters. + Default: {'lora_A', 'lora_B', 'lora_'} + **kwargs: Additional arguments passed to DDP config. + - use_distributed_optimizer: bool (default False for LoRA) + - grad_reduce_in_fp32: bool (default False) + + Returns: + self for method chaining. + + Example: + >>> model = MegatronModel(...) + >>> model.add_adapter_to_model('lora', lora_config) + >>> model.wrap_with_lora_ddp( + ... adapter_name='lora', + ... overlap_grad_reduce=True, + ... ) + >>> # Now training will use optimized DDP + >>> for batch in dataloader: + ... loss = model.forward_backward(inputs=batch) + ... model.step() + ... 
model.zero_grad() + """ + from twinkle.megatron.distributed import wrap_model_with_lora_ddp + from megatron.core.distributed import DistributedDataParallelConfig + + # Check if already wrapped + if self._is_model_ddp_wrapped(): + if mpu.get_data_parallel_rank() == 0: + print("Warning: Model is already DDP wrapped. Skipping wrap_with_lora_ddp().") + return self + + # Get the transformer config from the bridge initializer + transformer_config = getattr(self, '_transformer_config', None) + if transformer_config is None: + # Try to get from strategy + if hasattr(self.strategy, 'transformer_config'): + transformer_config = self.strategy.transformer_config + else: + raise ValueError( + "Cannot find TransformerConfig. " + "Make sure model is created via MegatronModel." + ) + + # Create DDP config + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=overlap_grad_reduce, + use_distributed_optimizer=kwargs.get('use_distributed_optimizer', False), + grad_reduce_in_fp32=kwargs.get('grad_reduce_in_fp32', False), + bucket_size=bucket_size, + average_in_collective=kwargs.get('average_in_collective', False), + ) + + # Get tenant process group if multi-tenant + tenant_process_group = kwargs.get('tenant_process_group', None) + + # Wrap model + self.model = wrap_model_with_lora_ddp( + model=self.model, + config=transformer_config, + ddp_config=ddp_config, + lora_param_patterns=lora_param_patterns, + tenant_id=adapter_name, + tenant_process_group=tenant_process_group, + ) + + # CRITICAL: Update transformer_config.no_sync_func to use the DDP's no_sync + # This is needed for Megatron's forward_backward_func to properly control + # gradient synchronization during gradient accumulation + transformer_config.no_sync_func = self.model.no_sync + + # Also update finalize_model_grads_func to use the DDP's finish_grad_sync + # instead of the custom PEFT version + def finalize_model_grads_for_ddp(model_list, *args, **kwargs): + """Finalize gradients for DDP-wrapped model.""" + for 
model_chunk in model_list: + if hasattr(model_chunk, 'finish_grad_sync'): + model_chunk.finish_grad_sync() + transformer_config.finalize_model_grads_func = finalize_model_grads_for_ddp + + if mpu.get_data_parallel_rank() == 0: + lora_count = self.model.get_lora_param_count() + lora_numel = self.model.get_lora_param_numel() + print(f"Wrapped model with LoRA DDP: {lora_count} params, {lora_numel:,} elements") + print(f" overlap_grad_reduce={overlap_grad_reduce}") + print(f" bucket_size={bucket_size or 'auto'}") + + return self @remote_function(dispatch='all') def zero_grad(self, **kwargs): @@ -869,22 +1044,29 @@ def zero_grad(self, **kwargs): For DDP-wrapped models, also zeros the DDP gradient buffers. + Note: For DDP-wrapped models, zero_grad_buffer() is always called + because it's essential for the next training iteration. The + do_grad_sync check only affects the optimizer.zero_grad() call. + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] + # For DDP-wrapped models, ALWAYS zero the gradient buffer + # This is essential because Megatron's forward_backward_func uses + # the buffer's state to track gradient accumulation + if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'): + self.model.zero_grad_buffer() + if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): return optimizer = optimizer_config.optimizer if optimizer is not None: - optimizer.zero_grad(**kwargs) - - # For Megatron DDP, zero the gradient buffer - if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'): - self.model.zero_grad_buffer() + # Clear set_to_none for better compatibility + optimizer.zero_grad(set_to_none=True) @remote_function() def lr_step(self, **kwargs): @@ -936,11 +1118,21 @@ def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs): Args: optimizer_cls: Optimizer class or string name. 
+ - Standard PyTorch optimizers: 'AdamW', 'Adam', 'SGD', etc. + - 'MegatronDistributed': Use Megatron's distributed optimizer **kwargs: Additional arguments. + - For standard optimizers: lr, weight_decay, etc. + - For MegatronDistributed: use_distributed_optimizer, clip_grad, etc. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] + # Check if requesting Megatron distributed optimizer + if optimizer_cls == 'MegatronDistributed' or kwargs.pop('use_megatron_optimizer', False): + optimizer_config.optimizer = self._create_megatron_optimizer(**kwargs) + optimizer_config.is_megatron_optimizer = True + return + if isinstance(optimizer_cls, str): if hasattr(torch.optim, optimizer_cls): optimizer_cls = getattr(torch.optim, optimizer_cls) @@ -950,6 +1142,73 @@ def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs): optimizer_config.optimizer = optimizer_cls( self._get_trainable_parameters(adapter_name).values(), **kwargs ) + optimizer_config.is_megatron_optimizer = False + + def _create_megatron_optimizer(self, **kwargs): + """Create Megatron distributed optimizer. + + This provides significant memory savings for large models by sharding + optimizer states across DP replicas. + + Args: + **kwargs: Optimizer configuration options. + - lr: Learning rate (default: 1e-4) + - weight_decay: Weight decay (default: 0.0) + - use_distributed_optimizer: Shard optimizer states (default: True) + - clip_grad: Gradient clipping threshold (default: 1.0) + - bf16: Use bf16 training (default: True) + - adam_beta1, adam_beta2, adam_eps: Adam parameters + + Returns: + MegatronOptimizer instance. 
+ """ + from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig + + # Build optimizer config + lr = kwargs.get('lr', 1e-4) + use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=lr, + min_lr=kwargs.get('min_lr', 0.0), + weight_decay=kwargs.get('weight_decay', 0.0), + adam_beta1=kwargs.get('adam_beta1', 0.9), + adam_beta2=kwargs.get('adam_beta2', 0.999), + adam_eps=kwargs.get('adam_eps', 1e-8), + clip_grad=kwargs.get('clip_grad', 1.0), + bf16=kwargs.get('bf16', True), + use_distributed_optimizer=use_distributed_optimizer, + overlap_param_gather=kwargs.get('overlap_param_gather', False), + log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), + ) + + # For PEFT models, we need to handle the case where model is not DDP-wrapped + # We create a temporary wrapper to satisfy Megatron's optimizer requirements + model_chunks = [self.model] + + # Check if model has ddp_config (required for distributed optimizer) + if not hasattr(self.model, 'ddp_config') and use_distributed_optimizer: + # For PEFT models without DDP, fall back to non-distributed optimizer + # but still use Megatron's optimized implementation + opt_config.use_distributed_optimizer = False + if mpu.get_data_parallel_rank() == 0: + print("Note: Falling back to non-distributed optimizer for PEFT model. 
" + "For distributed optimizer, wrap model with MegatronDDP.") + + try: + optimizer = get_megatron_optimizer( + config=opt_config, + model_chunks=model_chunks, + ) + return optimizer + except Exception as e: + # Fallback to simple FP32 optimizer if Megatron optimizer fails + if mpu.get_data_parallel_rank() == 0: + print(f"Warning: Failed to create Megatron optimizer ({e}), falling back to PyTorch AdamW") + + params = [p for p in self.model.parameters() if p.requires_grad] + return torch.optim.AdamW(params, lr=lr, weight_decay=kwargs.get('weight_decay', 0.0)) def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) -> Dict[str, nn.Parameter]: """Get trainable parameters. @@ -1215,22 +1474,37 @@ def finish_grad_sync(): """Synchronize gradients across DP ranks for non-DDP models. This is a compatibility shim for Megatron's finalize_model_grads. - For PEFT/LoRA models, we manually all-reduce gradients. + For PEFT/LoRA models, we manually all-reduce only trainable (LoRA) gradients. + + Optimizations: + 1. Only process gradients of trainable parameters (LoRA weights) + 2. Skip if DP size is 1 (no synchronization needed) + 3. 
Use coalesced all-reduce for efficiency """ dp_world_size = mpu.get_data_parallel_world_size() - if dp_world_size > 1: - dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) - grads = [] - for param in self.model.parameters(): - if param.requires_grad and param.grad is not None: - grads.append(param.grad.data) - - if grads: - from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=dp_cp_group) - for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - grad.copy_(synced) + if dp_world_size <= 1: + return # No sync needed for DP=1 + + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + grads = [] + + # Only collect gradients from trainable parameters (LoRA weights) + # This is much faster than iterating all parameters + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + grads.append(param.grad.data) + + if not grads: + return # No gradients to sync + + # Coalesced all-reduce for efficiency + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=dp_cp_group) + + # Copy back synchronized gradients + for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + grad.copy_(synced) self.model.finish_grad_sync = finish_grad_sync diff --git a/test_ray_configs.py b/test_ray_configs.py new file mode 100644 index 00000000..5bf6be09 --- /dev/null +++ b/test_ray_configs.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +"""Test script for Ray mode with various parallelism configurations. + +Records loss, memory usage, and training time. 
+""" +import os +import sys +import time +import subprocess +import re + +# Test configurations: (tp_size, pp_size, num_gpus, name) +CONFIGS = [ + (2, 2, 4, "TP=2_PP=2"), + (4, 1, 4, "TP=4_PP=1"), + (1, 4, 4, "TP=1_PP=4"), + (2, 1, 2, "TP=2_PP=1"), +] + +MODEL = "ms://Qwen/Qwen2.5-0.5B-Instruct" +MAX_STEPS = 5 +TIMEOUT = 600 # 10 minutes per test + +def run_test(mode, tp_size, pp_size, num_gpus, name): + """Run a single test configuration.""" + env = os.environ.copy() + env["MEGATRON_LM_PATH"] = "/mnt/nas2/hujinghan.hjh/Megatron-LM" + env["PYTHONPATH"] = "/mnt/nas2/hujinghan.hjh/Megatron-LM:/mnt/nas2/hujinghan.hjh/twinkle/src:" + env.get("PYTHONPATH", "") + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(num_gpus)) + env["TRUST_REMOTE_CODE"] = "1" + + log_file = f"/mnt/nas2/hujinghan.hjh/twinkle/test_{mode}_{name}.log" + + if mode == "ray": + cmd = [ + "/mnt/nas2/anaconda3/envs/hjh/bin/python", + "cookbook/megatron/lora.py", + "--mode", "ray", + "--tp_size", str(tp_size), + "--pp_size", str(pp_size), + "--num_gpus", str(num_gpus), + "--model", MODEL, + "--max_steps", str(MAX_STEPS), + ] + else: + # Find an available port + import socket + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + cmd = [ + "/mnt/nas2/anaconda3/envs/hjh/bin/python", "-m", "torch.distributed.run", + "--nproc_per_node", str(num_gpus), + "--master_port", str(port), + "cookbook/megatron/lora.py", + "--tp_size", str(tp_size), + "--pp_size", str(pp_size), + "--model", MODEL, + "--max_steps", str(MAX_STEPS), + ] + + print(f"\n{'='*60}") + print(f"Running: {mode} mode, {name}") + print(f"Command: {' '.join(cmd)}") + print(f"Log: {log_file}") + print(f"{'='*60}") + + start_time = time.time() + + with open(log_file, "w") as f: + try: + result = subprocess.run( + cmd, + cwd="/mnt/nas2/hujinghan.hjh/twinkle", + env=env, + stdout=f, + stderr=subprocess.STDOUT, + timeout=TIMEOUT, + ) + success = result.returncode == 0 + except subprocess.TimeoutExpired: + print(f" 
TIMEOUT after {TIMEOUT}s") + success = False + except Exception as e: + print(f" ERROR: {e}") + success = False + + elapsed = time.time() - start_time + + # Parse results from log + losses = [] + memory = None + + with open(log_file, "r") as f: + content = f.read() + + # Extract losses + for match in re.finditer(r"Step (\d+), loss: ([\d.]+)", content): + step = int(match.group(1)) + loss = float(match.group(2)) + losses.append((step, loss)) + + # Check for completion + completed = "Training completed!" in content + + return { + "mode": mode, + "config": name, + "tp": tp_size, + "pp": pp_size, + "gpus": num_gpus, + "losses": losses, + "elapsed": elapsed, + "success": success and completed, + "log_file": log_file, + } + + +def cleanup(): + """Kill any lingering processes.""" + os.system("pkill -9 -f 'lora.py|MegatronModel|ray' 2>/dev/null") + time.sleep(5) + + +def main(): + results = [] + + for tp, pp, gpus, name in CONFIGS: + # Test Ray mode + cleanup() + ray_result = run_test("ray", tp, pp, gpus, name) + results.append(ray_result) + + # Test Local mode + cleanup() + local_result = run_test("local", tp, pp, gpus, name) + results.append(local_result) + + cleanup() + + # Print summary + print("\n" + "="*80) + print("SUMMARY") + print("="*80) + print(f"{'Mode':<8} {'Config':<15} {'GPUs':<6} {'Status':<10} {'Time(s)':<10} {'Step0 Loss':<12} {'Step5 Loss':<12}") + print("-"*80) + + for r in results: + status = "✅ OK" if r["success"] else "❌ FAIL" + step0_loss = r["losses"][0][1] if len(r["losses"]) > 0 else "N/A" + step5_loss = r["losses"][-1][1] if len(r["losses"]) > 5 else "N/A" + if isinstance(step0_loss, float): + step0_loss = f"{step0_loss:.4f}" + if isinstance(step5_loss, float): + step5_loss = f"{step5_loss:.4f}" + print(f"{r['mode']:<8} {r['config']:<15} {r['gpus']:<6} {status:<10} {r['elapsed']:<10.1f} {step0_loss:<12} {step5_loss:<12}") + + print("="*80) + + # Save results to file + with open("/mnt/nas2/hujinghan.hjh/twinkle/test_results.txt", "w") as f: + 
f.write("Ray Mode Parallelism Test Results\n") + f.write("="*80 + "\n\n") + for r in results: + f.write(f"Mode: {r['mode']}, Config: {r['config']}, GPUs: {r['gpus']}\n") + f.write(f"Success: {r['success']}, Time: {r['elapsed']:.1f}s\n") + f.write(f"Losses: {r['losses']}\n") + f.write(f"Log: {r['log_file']}\n") + f.write("-"*40 + "\n") + + +if __name__ == "__main__": + main() From 94e27918f15a6eeaac6077349745c703ecc3e603 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 15 Jan 2026 20:05:03 +0800 Subject: [PATCH 07/22] wip --- src/twinkle/megatron/distributed/__init__.py | 16 +- src/twinkle/megatron/distributed/lora_ddp.py | 531 ------------------ .../megatron/distributed/multi_tenant_ddp.py | 342 +++++++++++ .../megatron/model/multi_tenant_megatron.py | 517 +++++++++++++++++ src/twinkle/model/megatron.py | 13 +- src/twinkle/model/strategy/megatron.py | 1 + 6 files changed, 871 insertions(+), 549 deletions(-) delete mode 100644 src/twinkle/megatron/distributed/lora_ddp.py create mode 100644 src/twinkle/megatron/distributed/multi_tenant_ddp.py create mode 100644 src/twinkle/megatron/model/multi_tenant_megatron.py diff --git a/src/twinkle/megatron/distributed/__init__.py b/src/twinkle/megatron/distributed/__init__.py index 9c4a411a..76a7c8a0 100644 --- a/src/twinkle/megatron/distributed/__init__.py +++ b/src/twinkle/megatron/distributed/__init__.py @@ -1,12 +1,16 @@ # Copyright (c) twinkle authors. All rights reserved. 
-"""Distributed training utilities for Megatron-based models.""" -from .lora_ddp import ( - LoRADistributedDataParallel, - wrap_model_with_lora_ddp, + +from .multi_tenant_ddp import ( + MultiTenantLoRADDP, + TenantContext, + TenantGradientManager, + create_multi_tenant_ddp, ) __all__ = [ - 'LoRADistributedDataParallel', - 'wrap_model_with_lora_ddp', + 'MultiTenantLoRADDP', + 'TenantContext', + 'TenantGradientManager', + 'create_multi_tenant_ddp', ] diff --git a/src/twinkle/megatron/distributed/lora_ddp.py b/src/twinkle/megatron/distributed/lora_ddp.py deleted file mode 100644 index ae5a9136..00000000 --- a/src/twinkle/megatron/distributed/lora_ddp.py +++ /dev/null @@ -1,531 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -LoRA-aware Distributed Data Parallel wrapper for Megatron models. - -This module provides a DDP wrapper that: -1. Only creates gradient buffers for LoRA parameters (trainable) -2. Supports communication-computation overlap -3. Supports multi-tenant LoRA training with separate process groups -4. 
Inherits from Megatron DDP to reuse optimized communication logic -""" - -import logging -from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Union - -import torch -import torch.distributed as dist -import torch.nn as nn - -logger = logging.getLogger(__name__) - -try: - from megatron.core import parallel_state as mpu - from megatron.core.distributed import DistributedDataParallel as MegatronDDP - from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig - from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets - from megatron.core.distributed.data_parallel_base import _BaseDataParallel - from megatron.core.transformer.transformer_config import TransformerConfig - from megatron.core.process_groups_config import ProcessGroupCollection - MEGATRON_AVAILABLE = True -except ImportError: - MEGATRON_AVAILABLE = False - MegatronDDP = object - _BaseDataParallel = object - - -class LoRADistributedDataParallel(_BaseDataParallel): - """ - Distributed Data Parallel wrapper for LoRA/PEFT models. - - This class inherits from Megatron's _BaseDataParallel and implements - DDP functionality specifically for LoRA parameters. Key features: - - 1. **Selective Parameter Registration**: Only LoRA parameters (trainable) - are registered for gradient synchronization, reducing memory overhead. - - 2. **Communication-Computation Overlap**: When overlap_grad_reduce=True, - gradient all-reduce operations are overlapped with backward computation. - - 3. **Gradient Bucketing**: Parameters are grouped into buckets for efficient - communication, reducing kernel launch overhead. - - 4. **Multi-Tenant Support**: Each tenant can have its own process group - for gradient synchronization. - - 5. **Dynamic Parameter Updates**: Supports adding/removing LoRA parameters - at runtime (requires buffer rebuild). - - Args: - config: Transformer configuration. 
- ddp_config: DDP configuration controlling overlap, bucketing, etc. - module: The model containing LoRA layers. - disable_bucketing: If True, all parameters go into a single bucket. - lora_param_patterns: Set of patterns to identify LoRA parameters. - tenant_id: Identifier for multi-tenant scenarios. - tenant_process_group: Custom process group for this tenant. - - Example: - >>> # Create DDP wrapper for LoRA model - >>> ddp_config = DistributedDataParallelConfig( - ... overlap_grad_reduce=True, - ... bucket_size=10000000, - ... ) - >>> ddp_model = LoRADistributedDataParallel( - ... config=transformer_config, - ... ddp_config=ddp_config, - ... module=lora_model, - ... ) - >>> - >>> # Training loop - >>> for batch in dataloader: - ... output = ddp_model(batch) - ... loss = compute_loss(output) - ... loss.backward() - ... ddp_model.finish_grad_sync() # Wait for async grad sync - ... optimizer.step() - ... ddp_model.zero_grad_buffer() - """ - - # Default patterns to identify LoRA parameters - DEFAULT_LORA_PATTERNS = {'lora_A', 'lora_B', 'lora_'} - - def __init__( - self, - config: 'TransformerConfig', - ddp_config: 'DistributedDataParallelConfig', - module: nn.Module, - disable_bucketing: bool = False, - lora_param_patterns: Optional[Set[str]] = None, - tenant_id: str = 'default', - tenant_process_group: Optional[dist.ProcessGroup] = None, - ): - if not MEGATRON_AVAILABLE: - raise ImportError("Megatron-Core is required for LoRADistributedDataParallel") - - super().__init__(config=config, module=module) - - self.ddp_config = ddp_config - self.tenant_id = tenant_id - self.lora_param_patterns = lora_param_patterns or self.DEFAULT_LORA_PATTERNS - self._disable_bucketing = disable_bucketing - - # Setup process groups - self._setup_process_groups(tenant_process_group) - - # Configure bucket size - if ddp_config.bucket_size is None: - # Use smaller default for LoRA (fewer parameters) - ddp_config.bucket_size = max( - 10000000, 500000 * self.dp_group.size() - ) - if not 
ddp_config.overlap_grad_reduce: - ddp_config.bucket_size = None - - self.bucket_size = ddp_config.bucket_size - if disable_bucketing: - self.bucket_size = None - - # Initialize data structures - self.param_to_bucket_group = {} - self.params_with_grad = [] - self.buffers: List['_ParamAndGradBuffer'] = [] - self.bucket_groups = [] - self.expert_parallel_buffers = [] - self.expert_parallel_bucket_groups = [] - self.grad_accs = [] - - # Register LoRA parameters and hooks - self._register_lora_params() - self._register_backward_hooks() - - # Forward hooks for param gather overlap (usually not needed for LoRA) - self.use_forward_hook = ( - ddp_config.use_distributed_optimizer and ddp_config.overlap_param_gather - ) - self.remove_forward_pre_hook_handles = {} - self.overlap_param_gather_with_optimizer_step = False - - logger.info( - f"LoRADistributedDataParallel initialized for tenant '{tenant_id}' " - f"with {len(self.params_with_grad)} LoRA parameters, " - f"{len(self.bucket_groups)} bucket groups" - ) - - def _setup_process_groups(self, tenant_process_group: Optional[dist.ProcessGroup]): - """ - Setup process groups for gradient communication. - - If tenant_process_group is provided, use it for DP communication. - Otherwise, use the default Megatron parallel state groups. 
- """ - if tenant_process_group is not None: - # Use custom tenant process group - self.dp_group = tenant_process_group - self.dp_cp_group = tenant_process_group - self.intra_dp_cp_group = tenant_process_group - # Expert groups use defaults (MoE multi-tenant not supported yet) - try: - self.expt_dp_group = mpu.get_expert_data_parallel_group() - self.intra_expt_dp_group = mpu.get_expert_data_parallel_group( - partial_expert_data_parallel=True - ) - except: - self.expt_dp_group = None - self.intra_expt_dp_group = None - else: - # Use default Megatron process groups - self.dp_group = mpu.get_data_parallel_group( - with_context_parallel=False, partial_data_parallel=False - ) - self.dp_cp_group = mpu.get_data_parallel_group( - with_context_parallel=True, partial_data_parallel=False - ) - self.intra_dp_cp_group = mpu.get_data_parallel_group( - with_context_parallel=True, partial_data_parallel=True - ) - try: - self.expt_dp_group = mpu.get_expert_data_parallel_group() - self.intra_expt_dp_group = mpu.get_expert_data_parallel_group( - partial_expert_data_parallel=True - ) - except: - self.expt_dp_group = None - self.intra_expt_dp_group = None - - self.tp_group = mpu.get_tensor_model_parallel_group() - self.pp_group = mpu.get_pipeline_model_parallel_group() - try: - self.ep_group = mpu.get_expert_model_parallel_group() - except: - self.ep_group = None - - def _is_lora_param(self, name: str) -> bool: - """Check if a parameter is a LoRA parameter based on name patterns.""" - for pattern in self.lora_param_patterns: - if pattern in name: - return True - return False - - def _register_lora_params(self): - """ - Register LoRA parameters to gradient buffers. - - This method: - 1. Identifies LoRA parameters by name patterns - 2. Groups them by dtype - 3. Creates gradient buffers for efficient communication - 4. 
Sets up bucket groups for overlapped communication - """ - param_to_name = {} - lora_params = [] - - for name, param in self.module.named_parameters(): - if not param.requires_grad: - continue - - # Only process LoRA parameters - if not self._is_lora_param(name): - continue - - self.params_with_grad.append(param) - param.grad_added_to_main_grad = False - param_to_name[param] = name - lora_params.append(param) - - if not lora_params: - logger.warning( - f"No LoRA parameters found for tenant '{self.tenant_id}'. " - f"Patterns used: {self.lora_param_patterns}" - ) - return - - # Calculate gradient scaling factor - if self.config.calculate_per_token_loss: - gradient_scaling_factor = 1.0 - else: - if self.ddp_config.average_in_collective: - gradient_scaling_factor = 1.0 - else: - gradient_scaling_factor = 1.0 / self.dp_cp_group.size() - - # Group parameters by dtype - param_and_grad_dtype_to_params = {} - param_and_grad_dtype_to_indices = {} - - for param in lora_params: - param_dtype = param.dtype - grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype - - key = (param_dtype, grad_dtype) - if key not in param_and_grad_dtype_to_params: - param_and_grad_dtype_to_params[key] = [] - param_and_grad_dtype_to_indices[key] = [] - param_and_grad_dtype_to_params[key].append(param) - param_and_grad_dtype_to_indices[key].append(len(param_and_grad_dtype_to_params[key]) - 1) - - # Create gradient buffers for each dtype combination - pg_collection = ProcessGroupCollection() - pg_collection.tp = self.tp_group - pg_collection.dp_cp = self.dp_cp_group - - for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items(): - indices = param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] - - buffer = _ParamAndGradBuffer( - self.ddp_config, - param_dtype, - grad_dtype, - params, - self.intra_dp_cp_group, - self.bucket_size, - param_to_name, - gradient_scaling_factor, - indices, - getattr(self.ddp_config, 'nccl_ub', False), - pg_collection, - ) 
- self.buffers.append(buffer) - - # Create bucket groups - self.bucket_groups = partition_buckets( - self.buffers, - force_single_bucket_group=self._disable_bucketing - ) - - # Build param to bucket group mapping - for bucket_group in self.bucket_groups: - for bucket in bucket_group.buckets: - for param in bucket.params: - self.param_to_bucket_group[param] = bucket_group - - def _register_backward_hooks(self): - """ - Register backward hooks for LoRA parameters. - - These hooks: - 1. Accumulate gradients to main_grad buffer - 2. Trigger async gradient communication when a bucket is ready - """ - for param in self.params_with_grad: - if param not in self.param_to_bucket_group: - continue - - # Get gradient accumulator and register hook - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_backward_post_hook(param)) - self.grad_accs.append(grad_acc) - - def _make_backward_post_hook(self, param: nn.Parameter): - """ - Create a backward post-hook for a parameter. - - When the parameter's gradient is computed: - 1. Accumulate it to the main_grad buffer - 2. If overlap is enabled AND this is the last microbatch, start async communication - - Note: register_grad_ready() internally checks is_last_microbatch, so we don't - need to check it here. The bucket_group will only start communication when - all params are ready AND it's the last microbatch. 
- """ - def hook(*unused): - if param in self.param_to_bucket_group: - # Accumulate gradient to main_grad - if param.grad is not None and not param.grad_added_to_main_grad: - param.main_grad.add_(param.grad.data) - param.grad = None - - # If overlap enabled, notify bucket that param is ready - # Note: register_grad_ready internally checks is_last_microbatch - # and only registers when processing the last microbatch - if self.ddp_config.overlap_grad_reduce: - bucket_group = self.param_to_bucket_group[param] - # Only register if this is the last microbatch - # (bucket_group.is_last_microbatch controls this) - if bucket_group.is_last_microbatch: - bucket_group.register_grad_ready(param) - - return hook - - @contextmanager - def no_sync(self): - """ - Context manager that turns off gradient synchronization. - - Use this for gradient accumulation - only sync on the last microbatch. - """ - for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: - bucket_group.is_last_microbatch = False - try: - yield - finally: - for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: - bucket_group.is_last_microbatch = True - - def start_grad_sync(self, *unused): - """ - Start gradient synchronization (all-reduce or reduce-scatter). - - When overlap_grad_reduce=True, this dispatches async operations. - When overlap_grad_reduce=False, this is a no-op (finish_grad_sync does sync). - """ - for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: - bucket_group.start_grad_sync() - - def finish_grad_sync(self): - """ - Finish gradient synchronization. - - When overlap_grad_reduce=True, waits for async operations to complete. - When overlap_grad_reduce=False, performs synchronous communication. 
- """ - for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: - bucket_group.finish_grad_sync() - - def scale_gradients(self, scaling_factor: float): - """Scale all gradients in buffers by the given factor.""" - for buffer in self.buffers + self.expert_parallel_buffers: - buffer.scale_gradients(scaling_factor) - - def zero_grad_buffer(self): - """ - Zero out all gradient buffers. - - Must be called at the beginning of each training iteration. - """ - for param in self.params_with_grad: - param.grad_added_to_main_grad = False - for buffer in self.buffers + self.expert_parallel_buffers: - buffer.reset() - for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: - bucket_group.reset() - - def broadcast_params(self): - """Broadcast parameters from rank 0 to all DP ranks.""" - for param in self.params_with_grad: - dist.broadcast( - param.data, - src=dist.get_global_rank(self.dp_cp_group, 0), - group=self.dp_cp_group, - ) - - def add_lora_params(self, new_params: Dict[str, nn.Parameter]): - """ - Dynamically add LoRA parameters. - - Note: This requires rebuilding gradient buffers, which is expensive. - Use sparingly. - - Args: - new_params: Dictionary mapping parameter names to parameters. - """ - for name, param in new_params.items(): - if param.requires_grad: - self.params_with_grad.append(param) - param.grad_added_to_main_grad = False - - self._rebuild_buffers() - - def remove_lora_params(self, param_names: Set[str]): - """ - Remove LoRA parameters. - - Note: This requires rebuilding gradient buffers, which is expensive. - - Args: - param_names: Set of parameter names to remove. 
- """ - new_params_with_grad = [] - for param in self.params_with_grad: - # Find param name in module - for name, p in self.module.named_parameters(): - if p is param and name not in param_names: - new_params_with_grad.append(param) - break - - self.params_with_grad = new_params_with_grad - self._rebuild_buffers() - - def _rebuild_buffers(self): - """Rebuild gradient buffers after parameter changes.""" - # Clear old hooks and buffers - self.grad_accs.clear() - self.param_to_bucket_group.clear() - self.buffers.clear() - self.bucket_groups.clear() - - # Re-register - self._register_lora_params() - self._register_backward_hooks() - - logger.info( - f"Rebuilt buffers for tenant '{self.tenant_id}': " - f"{len(self.params_with_grad)} params, {len(self.bucket_groups)} bucket groups" - ) - - def get_lora_param_count(self) -> int: - """Get the number of registered LoRA parameters.""" - return len(self.params_with_grad) - - def get_lora_param_numel(self) -> int: - """Get the total number of elements in LoRA parameters.""" - return sum(p.numel() for p in self.params_with_grad) - - -def wrap_model_with_lora_ddp( - model: nn.Module, - config: 'TransformerConfig', - ddp_config: Optional['DistributedDataParallelConfig'] = None, - lora_param_patterns: Optional[Set[str]] = None, - tenant_id: str = 'default', - tenant_process_group: Optional[dist.ProcessGroup] = None, - overlap_grad_reduce: bool = True, - bucket_size: Optional[int] = None, -) -> LoRADistributedDataParallel: - """ - Convenience function to wrap a LoRA model with DDP. - - This is the recommended way to create a LoRADistributedDataParallel wrapper. - - Args: - model: Model containing LoRA layers. - config: Transformer configuration. - ddp_config: DDP configuration. If None, creates default config. - lora_param_patterns: Patterns to identify LoRA parameters. - tenant_id: Tenant identifier for multi-tenant scenarios. - tenant_process_group: Custom process group for this tenant. 
- overlap_grad_reduce: Enable communication-computation overlap. - bucket_size: Size of gradient buckets. None for auto. - - Returns: - LoRADistributedDataParallel wrapper. - - Example: - >>> ddp_model = wrap_model_with_lora_ddp( - ... model=lora_model, - ... config=transformer_config, - ... overlap_grad_reduce=True, - ... ) - """ - if not MEGATRON_AVAILABLE: - raise ImportError("Megatron-Core is required for wrap_model_with_lora_ddp") - - if ddp_config is None: - ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=overlap_grad_reduce, - use_distributed_optimizer=False, # LoRA params are small - bucket_size=bucket_size, - ) - - if lora_param_patterns is None: - lora_param_patterns = LoRADistributedDataParallel.DEFAULT_LORA_PATTERNS - - return LoRADistributedDataParallel( - config=config, - ddp_config=ddp_config, - module=model, - lora_param_patterns=lora_param_patterns, - tenant_id=tenant_id, - tenant_process_group=tenant_process_group, - ) diff --git a/src/twinkle/megatron/distributed/multi_tenant_ddp.py b/src/twinkle/megatron/distributed/multi_tenant_ddp.py new file mode 100644 index 00000000..9f84f230 --- /dev/null +++ b/src/twinkle/megatron/distributed/multi_tenant_ddp.py @@ -0,0 +1,342 @@ +# Copyright (c) twinkle authors. All rights reserved. +""" +Multi-Tenant LoRA DDP for Megatron models. + +This module provides a minimal, maintainable DDP solution for multi-tenant LoRA training: +1. Inherits from Megatron's DistributedDataParallel to maximize code reuse +2. Uses MultiAdapter's ContextVar mechanism for tenant isolation +3. Supports per-tenant process groups for gradient synchronization + +Key insight: Megatron DDP already only creates buffers for requires_grad=True params, +so we just need to control which params are trainable per-tenant. 
+""" + +import contextvars +import logging +from typing import Dict, List, Optional, Set + +import torch.distributed as dist +import torch.nn as nn + +logger = logging.getLogger(__name__) + +try: + from megatron.core import parallel_state as mpu + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig + from megatron.core.transformer.transformer_config import TransformerConfig + MEGATRON_AVAILABLE = True +except ImportError: + MEGATRON_AVAILABLE = False + MegatronDDP = object + + +class TenantContext: + """ + Thread/coroutine-safe tenant context using ContextVar. + + This integrates with MultiAdapter's ContextVar mechanism to ensure + that each request/coroutine operates on the correct tenant's LoRA weights. + """ + + _current_tenant: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + 'current_tenant', default=None + ) + + @classmethod + def get_current_tenant(cls) -> Optional[str]: + return cls._current_tenant.get() + + @classmethod + def set_current_tenant(cls, tenant_id: str): + cls._current_tenant.set(tenant_id) + + @classmethod + def reset_tenant(cls): + cls._current_tenant.set(None) + + +class TenantGradientManager: + """ + Manages per-tenant gradient buffers and communication groups. + + This is a lightweight wrapper that doesn't duplicate Megatron DDP logic, + but instead coordinates gradient sync across tenants. + """ + + def __init__(self): + self.tenant_params: Dict[str, Set[nn.Parameter]] = {} + self.tenant_process_groups: Dict[str, dist.ProcessGroup] = {} + self.tenant_param_names: Dict[str, Dict[nn.Parameter, str]] = {} + + def register_tenant( + self, + tenant_id: str, + params: List[nn.Parameter], + param_names: Dict[nn.Parameter, str], + process_group: Optional[dist.ProcessGroup] = None, + ): + """ + Register a tenant with its LoRA parameters. + + Args: + tenant_id: Unique tenant identifier. 
+ params: List of LoRA parameters for this tenant. + param_names: Mapping from param to name for debugging. + process_group: Optional custom process group for this tenant. + """ + self.tenant_params[tenant_id] = set(params) + self.tenant_param_names[tenant_id] = param_names + + if process_group is not None: + self.tenant_process_groups[tenant_id] = process_group + else: + # Use default DP group + self.tenant_process_groups[tenant_id] = mpu.get_data_parallel_group( + with_context_parallel=True + ) + + logger.info( + f"Registered tenant '{tenant_id}' with {len(params)} parameters, " + f"process group size: {self.tenant_process_groups[tenant_id].size()}" + ) + + def unregister_tenant(self, tenant_id: str): + """Remove a tenant and its associated resources.""" + self.tenant_params.pop(tenant_id, None) + self.tenant_param_names.pop(tenant_id, None) + self.tenant_process_groups.pop(tenant_id, None) + logger.info(f"Unregistered tenant '{tenant_id}'") + + def get_tenant_params(self, tenant_id: str) -> Set[nn.Parameter]: + return self.tenant_params.get(tenant_id, set()) + + def get_tenant_process_group(self, tenant_id: str) -> Optional[dist.ProcessGroup]: + return self.tenant_process_groups.get(tenant_id) + + +class MultiTenantLoRADDP(MegatronDDP if MEGATRON_AVAILABLE else object): + """ + Multi-Tenant LoRA DDP wrapper that extends Megatron's DDP. + + Design principles: + 1. **Minimal override**: Only override what's necessary for multi-tenant support + 2. **Reuse Megatron DDP**: All gradient buffer management, bucketing, and + communication overlap logic is inherited from Megatron DDP + 3. **ContextVar integration**: Uses TenantContext for thread-safe tenant switching + 4. 
**Lazy buffer creation**: Buffers are created per-tenant on first use + + Key insight: Instead of creating a separate DDP per tenant, we: + - Keep one DDP instance with all LoRA parameters + - Use ContextVar to track current tenant + - Filter gradient sync to only current tenant's params + + Example: + >>> # Create multi-tenant DDP + >>> ddp = MultiTenantLoRADDP(config, ddp_config, model) + >>> + >>> # Register tenants + >>> ddp.register_tenant('tenant_a', tenant_a_params) + >>> ddp.register_tenant('tenant_b', tenant_b_params) + >>> + >>> # Training loop with tenant isolation + >>> TenantContext.set_current_tenant('tenant_a') + >>> output = ddp(input) # Uses tenant_a's LoRA + >>> loss.backward() + >>> ddp.finish_grad_sync() # Only syncs tenant_a's grads + """ + + # Default patterns to identify LoRA parameters + DEFAULT_LORA_PATTERNS = {'lora_A', 'lora_B', 'lora_'} + + def __init__( + self, + config: 'TransformerConfig', + ddp_config: 'DistributedDataParallelConfig', + module: nn.Module, + disable_bucketing: bool = False, + lora_param_patterns: Optional[Set[str]] = None, + **kwargs, + ): + """ + Initialize MultiTenantLoRADDP. + + This calls the parent Megatron DDP __init__ which will: + 1. Create gradient buffers for all requires_grad=True params + 2. Set up backward hooks + 3. Initialize bucket groups + + We then add multi-tenant management on top. 
+ """ + if not MEGATRON_AVAILABLE: + raise ImportError("Megatron-Core is required") + + self.lora_param_patterns = lora_param_patterns or self.DEFAULT_LORA_PATTERNS + self._tenant_manager = TenantGradientManager() + + # Pre-identify LoRA parameters before parent init + # This helps with debugging and tenant registration + self._lora_params: Dict[str, nn.Parameter] = {} + for name, param in module.named_parameters(): + if param.requires_grad and self._is_lora_param(name): + self._lora_params[name] = param + + logger.info(f"Identified {len(self._lora_params)} LoRA parameters") + + # Call parent Megatron DDP init + # This creates buffers for all requires_grad=True params + super().__init__( + config=config, + ddp_config=ddp_config, + module=module, + disable_bucketing=disable_bucketing, + **kwargs, + ) + + logger.info( + f"MultiTenantLoRADDP initialized with {len(self.params_with_grad)} " + f"trainable parameters, {len(self.bucket_groups)} bucket groups" + ) + + def _is_lora_param(self, name: str) -> bool: + """Check if parameter name matches LoRA patterns.""" + for pattern in self.lora_param_patterns: + if pattern in name: + return True + return False + + def register_tenant( + self, + tenant_id: str, + adapter_name: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, + ): + """ + Register a tenant for multi-tenant training. + + Args: + tenant_id: Unique tenant identifier. + adapter_name: PEFT adapter name (if different from tenant_id). + process_group: Custom process group for gradient sync. + """ + adapter_name = adapter_name or tenant_id + + # Find parameters belonging to this adapter + tenant_params = [] + param_names = {} + + for name, param in self._lora_params.items(): + # Match adapter name in parameter path + # e.g., "model.layers.0.self_attn.q_proj.lora_A.tenant_a.weight" + if f'.{adapter_name}.' 
in name or name.endswith(f'.{adapter_name}'): + tenant_params.append(param) + param_names[param] = name + + if not tenant_params: + # If no adapter-specific match, assume all LoRA params belong to this tenant + # This handles single-tenant scenarios + logger.warning( + f"No adapter-specific params found for '{adapter_name}', " + f"registering all {len(self._lora_params)} LoRA params" + ) + tenant_params = list(self._lora_params.values()) + param_names = {v: k for k, v in self._lora_params.items()} + + self._tenant_manager.register_tenant( + tenant_id=tenant_id, + params=tenant_params, + param_names=param_names, + process_group=process_group, + ) + + def unregister_tenant(self, tenant_id: str): + """Remove a tenant.""" + self._tenant_manager.unregister_tenant(tenant_id) + + def set_current_tenant(self, tenant_id: str): + """ + Set the current tenant for subsequent operations. + + This should be called before forward/backward to ensure + correct LoRA adapter is used. + """ + TenantContext.set_current_tenant(tenant_id) + + def get_current_tenant(self) -> Optional[str]: + """Get the current tenant ID.""" + return TenantContext.get_current_tenant() + + def finish_grad_sync_for_tenant(self, tenant_id: Optional[str] = None): + """ + Finish gradient sync for a specific tenant. + + If tenant_id is None, uses current tenant from context. + If no tenant is set, falls back to syncing all params (parent behavior). 
+ """ + tenant_id = tenant_id or TenantContext.get_current_tenant() + + if tenant_id is None: + # No tenant specified, use default behavior + super().finish_grad_sync() + return + + # Get tenant's process group + pg = self._tenant_manager.get_tenant_process_group(tenant_id) + if pg is None: + logger.warning(f"Tenant '{tenant_id}' not registered, using default sync") + super().finish_grad_sync() + return + + # For now, use parent's finish_grad_sync + # In a more advanced implementation, we could filter to only + # sync the tenant's parameters, but Megatron's bucket design + # makes this complex + super().finish_grad_sync() + + def get_tenant_param_count(self, tenant_id: str) -> int: + """Get number of parameters for a tenant.""" + return len(self._tenant_manager.get_tenant_params(tenant_id)) + + def get_tenant_param_numel(self, tenant_id: str) -> int: + """Get total number of elements in tenant's parameters.""" + return sum(p.numel() for p in self._tenant_manager.get_tenant_params(tenant_id)) + + +def create_multi_tenant_ddp( + model: nn.Module, + config: 'TransformerConfig', + ddp_config: Optional['DistributedDataParallelConfig'] = None, + lora_param_patterns: Optional[Set[str]] = None, + overlap_grad_reduce: bool = True, + bucket_size: Optional[int] = None, +) -> MultiTenantLoRADDP: + """ + Factory function to create a MultiTenantLoRADDP wrapper. + + Args: + model: Model containing LoRA layers. + config: Transformer configuration. + ddp_config: DDP configuration. If None, creates default config. + lora_param_patterns: Patterns to identify LoRA parameters. + overlap_grad_reduce: Enable communication-computation overlap. + bucket_size: Size of gradient buckets. + + Returns: + MultiTenantLoRADDP wrapper. 
+ """ + if not MEGATRON_AVAILABLE: + raise ImportError("Megatron-Core is required") + + if ddp_config is None: + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=overlap_grad_reduce, + use_distributed_optimizer=False, + bucket_size=bucket_size, + ) + + return MultiTenantLoRADDP( + config=config, + ddp_config=ddp_config, + module=model, + lora_param_patterns=lora_param_patterns, + ) diff --git a/src/twinkle/megatron/model/multi_tenant_megatron.py b/src/twinkle/megatron/model/multi_tenant_megatron.py new file mode 100644 index 00000000..e280151f --- /dev/null +++ b/src/twinkle/megatron/model/multi_tenant_megatron.py @@ -0,0 +1,517 @@ +# Copyright (c) twinkle authors. All rights reserved. +""" +Multi-Tenant Megatron Model for LoRA training. + +This module provides multi-tenant LoRA training support for Megatron models, +similar to MultiLoraTransformersModel but optimized for Megatron's architecture. + +Key features: +1. Uses MultiAdapter's ContextVar mechanism for tenant isolation +2. Integrates with Megatron's parallel state and DDP +3. Supports per-tenant optimizers, schedulers, and gradient accumulation +4. 
Compatible with Swift Megatron's LoraParallelLinear +""" + +import contextvars +import logging +import re +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Type, Union + +import torch +import torch.distributed as dist +import torch.nn as nn + +logger = logging.getLogger(__name__) + +try: + from megatron.core import parallel_state as mpu + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig + from megatron.core.transformer.transformer_config import TransformerConfig + MEGATRON_AVAILABLE = True +except ImportError: + MEGATRON_AVAILABLE = False + +try: + from peft import LoraConfig, PeftModel + from peft.tuners.lora import LoraLayer, LoraModel + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + + +class MegatronMultiAdapter: + """ + Megatron-compatible MultiAdapter using ContextVar for tenant isolation. + + This patches LoraLayer/LoraModel to use ContextVar-based adapter selection, + enabling thread/coroutine-safe multi-tenant training. + + Key difference from twinkle's MultiAdapter: + - Also patches Swift Megatron's LoraParallelLinear if present + """ + + _adapter_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + 'megatron_adapter_name', default=None + ) + _patched: bool = False + + def __call__(self, module: nn.Module) -> nn.Module: + """ + Patch LoRA layers to use ContextVar-based adapter selection. + + Args: + module: Model containing LoRA layers. + + Returns: + Patched model (same instance, modified in-place). 
+ """ + if MegatronMultiAdapter._patched: + return module + + self._patch_peft_lora() + self._patch_megatron_lora() + + module.set_current_adapter_name = MegatronMultiAdapter.set_current_adapter_name + MegatronMultiAdapter._patched = True + + return module + + def _patch_peft_lora(self): + """Patch PEFT's LoraLayer and LoraModel.""" + if not PEFT_AVAILABLE: + return + + def get_active_adapter(*args, **kwargs): + return MegatronMultiAdapter._adapter_var.get() + + def get_active_adapters(*args, **kwargs): + adapter_name = MegatronMultiAdapter._adapter_var.get() + return [adapter_name] if adapter_name else [] + + def set_active_adapters(_, value): + pass # Controlled via ContextVar + + def set_adapter(self, adapter_names): + pass # Controlled via ContextVar + + def mark_only_adapters_trainable(self, model) -> None: + for n, p in model.named_parameters(): + p.requires_grad = "lora_" in n + + # Patch LoraLayer + LoraLayer.active_adapter = property(get_active_adapter, set_active_adapters) + LoraLayer.active_adapters = property(get_active_adapters, set_active_adapters) + LoraLayer.set_adapter = set_adapter + + # Patch LoraModel + LoraModel.active_adapter = property(get_active_adapter, set_active_adapters) + LoraModel.active_adapters = property(get_active_adapters, set_active_adapters) + LoraModel.set_adapter = set_adapter + LoraModel._mark_only_adapters_as_trainable = mark_only_adapters_trainable + + logger.info("Patched PEFT LoraLayer/LoraModel for multi-tenant support") + + def _patch_megatron_lora(self): + """Patch Swift Megatron's LoraParallelLinear if available.""" + try: + from swift.megatron.tuners.lora import LoraParallelLinear + + def get_active_adapter(self): + return MegatronMultiAdapter._adapter_var.get() + + def get_active_adapters(self): + adapter_name = MegatronMultiAdapter._adapter_var.get() + return [adapter_name] if adapter_name else [] + + # Patch as properties + if not hasattr(LoraParallelLinear, '_multi_tenant_patched'): + 
LoraParallelLinear.active_adapter = property(get_active_adapter) + LoraParallelLinear.active_adapters = property(get_active_adapters) + LoraParallelLinear._multi_tenant_patched = True + logger.info("Patched LoraParallelLinear for multi-tenant support") + except ImportError: + logger.debug("Swift Megatron LoraParallelLinear not available") + + @staticmethod + def set_current_adapter_name(adapter_name: Optional[str]): + """Set the current adapter for this context.""" + MegatronMultiAdapter._adapter_var.set(adapter_name) + + @staticmethod + def get_current_adapter_name() -> Optional[str]: + """Get the current adapter name.""" + return MegatronMultiAdapter._adapter_var.get() + + +@dataclass +class TenantState: + """State for a single tenant.""" + adapter_name: str + process_group: Optional[dist.ProcessGroup] = None + optimizer: Optional[torch.optim.Optimizer] = None + scheduler: Optional[Any] = None + grad_scaler: Optional[torch.cuda.amp.GradScaler] = None + lora_config: Optional['LoraConfig'] = None + + # Tracking + trainable_params: List[nn.Parameter] = field(default_factory=list) + param_names: Dict[nn.Parameter, str] = field(default_factory=dict) + + +class MultiTenantMegatronModel: + """ + Multi-Tenant Megatron Model wrapper for LoRA training. + + This class provides: + 1. Multi-tenant adapter management using ContextVar + 2. Per-tenant optimizer and scheduler + 3. Gradient synchronization with tenant-specific process groups + 4. Integration with Megatron's DDP + + Design: + - Uses a single Megatron DDP wrapper for all tenants + - Each tenant has isolated LoRA adapters + - ContextVar ensures thread-safe adapter switching + + Example: + >>> model = create_megatron_model(...) + >>> multi_tenant = MultiTenantMegatronModel(model, config, ddp_config) + >>> + >>> # Add tenants + >>> multi_tenant.add_tenant('user_a', lora_config_a) + >>> multi_tenant.add_tenant('user_b', lora_config_b) + >>> + >>> # Training + >>> with multi_tenant.tenant_context('user_a'): + ... 
output = multi_tenant(input) + ... loss.backward() + ... multi_tenant.step() + """ + + LORA_PARAM_PATTERN = re.compile(r'\.lora_\w+\.[^.]+\.') + + def __init__( + self, + model: nn.Module, + config: 'TransformerConfig', + ddp_config: Optional['DistributedDataParallelConfig'] = None, + default_dp_group: Optional[dist.ProcessGroup] = None, + ): + """ + Initialize multi-tenant model. + + Args: + model: Base Megatron model (can be already wrapped with PEFT). + config: Transformer configuration. + ddp_config: DDP configuration. If None, creates default. + default_dp_group: Default data parallel group for tenants. + """ + if not MEGATRON_AVAILABLE: + raise ImportError("Megatron-Core is required") + + self.config = config + self.ddp_config = ddp_config or DistributedDataParallelConfig( + overlap_grad_reduce=True, + use_distributed_optimizer=False, + ) + + # Setup multi-adapter + self._multi_adapter = MegatronMultiAdapter() + self.model = self._multi_adapter(model) + + # Tenant management + self._tenants: Dict[str, TenantState] = {} + self._default_dp_group = default_dp_group or mpu.get_data_parallel_group( + with_context_parallel=True + ) + + # DDP wrapper (created lazily after first tenant is added) + self._ddp: Optional[MegatronDDP] = None + + # Add a dummy adapter to ensure PEFT model structure is ready + self._ensure_peft_model() + + def _ensure_peft_model(self): + """Ensure the model is a PEFT model.""" + if not PEFT_AVAILABLE: + logger.warning("PEFT not available, skipping PEFT model check") + return + + if not isinstance(self.model, PeftModel): + # Create minimal LoRA config for structure + dummy_config = LoraConfig( + r=1, + target_modules='all-linear', + init_lora_weights=False, + ) + # Note: For Megatron models, you typically use Swift's prepare_model + logger.warning( + "Model is not a PeftModel. For Megatron LoRA, " + "use Swift.prepare_model() before wrapping." 
+ ) + + def _wrap_with_ddp(self): + """Wrap model with Megatron DDP (lazy initialization).""" + if self._ddp is not None: + return + + self._ddp = MegatronDDP( + config=self.config, + ddp_config=self.ddp_config, + module=self.model, + ) + logger.info( + f"Created Megatron DDP with {len(self._ddp.params_with_grad)} params, " + f"{len(self._ddp.bucket_groups)} bucket groups" + ) + + def add_tenant( + self, + tenant_id: str, + lora_config: Optional['LoraConfig'] = None, + process_group: Optional[dist.ProcessGroup] = None, + optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.AdamW, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler_cls: Optional[Type] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Add a tenant with their LoRA configuration. + + Args: + tenant_id: Unique tenant identifier. + lora_config: LoRA configuration. If None, assumes adapter already exists. + process_group: Custom process group for this tenant's gradient sync. + optimizer_cls: Optimizer class. + optimizer_kwargs: Optimizer arguments. + scheduler_cls: LR scheduler class. + scheduler_kwargs: Scheduler arguments. + """ + if tenant_id in self._tenants: + logger.warning(f"Tenant '{tenant_id}' already exists, skipping") + return + + adapter_name = tenant_id + + # Add adapter if config provided and using PEFT + if lora_config is not None and PEFT_AVAILABLE and isinstance(self.model, PeftModel): + # Safety checks + lora_config.modules_to_save = None + lora_config.bias = 'none' + + self.model.add_adapter(adapter_name, lora_config) + logger.info(f"Added LoRA adapter '{adapter_name}'") + + # Set adapter as active to find its params + MegatronMultiAdapter.set_current_adapter_name(adapter_name) + + # Find trainable params for this adapter + trainable_params = [] + param_names = {} + + for name, param in self.model.named_parameters(): + if self.LORA_PARAM_PATTERN.search(name) and f'.{adapter_name}.' 
in name: + param.requires_grad = True + trainable_params.append(param) + param_names[param] = name + + # Create tenant state + state = TenantState( + adapter_name=adapter_name, + process_group=process_group or self._default_dp_group, + lora_config=lora_config, + trainable_params=trainable_params, + param_names=param_names, + ) + + # Create optimizer + if optimizer_kwargs is None: + optimizer_kwargs = {'lr': 1e-4, 'weight_decay': 0.01} + + state.optimizer = optimizer_cls(trainable_params, **optimizer_kwargs) + + # Create scheduler if specified + if scheduler_cls is not None: + scheduler_kwargs = scheduler_kwargs or {} + state.scheduler = scheduler_cls(state.optimizer, **scheduler_kwargs) + + self._tenants[tenant_id] = state + + logger.info( + f"Registered tenant '{tenant_id}' with {len(trainable_params)} " + f"trainable params ({sum(p.numel() for p in trainable_params):,} elements)" + ) + + # Reset adapter context + MegatronMultiAdapter.set_current_adapter_name(None) + + def remove_tenant(self, tenant_id: str): + """Remove a tenant.""" + if tenant_id not in self._tenants: + logger.warning(f"Tenant '{tenant_id}' not found") + return + + state = self._tenants.pop(tenant_id) + + # Remove adapter from model if using PEFT + if PEFT_AVAILABLE and isinstance(self.model, PeftModel): + try: + self.model.delete_adapter(state.adapter_name) + except Exception as e: + logger.warning(f"Failed to delete adapter: {e}") + + logger.info(f"Removed tenant '{tenant_id}'") + + @contextmanager + def tenant_context(self, tenant_id: str): + """ + Context manager for tenant-specific operations. + + All forward/backward operations within this context will use + the specified tenant's LoRA adapter. 
+ """ + if tenant_id not in self._tenants: + raise ValueError(f"Tenant '{tenant_id}' not registered") + + state = self._tenants[tenant_id] + prev_adapter = MegatronMultiAdapter.get_current_adapter_name() + + try: + MegatronMultiAdapter.set_current_adapter_name(state.adapter_name) + yield state + finally: + MegatronMultiAdapter.set_current_adapter_name(prev_adapter) + + def forward(self, *args, tenant_id: Optional[str] = None, **kwargs): + """ + Forward pass with tenant selection. + + Args: + *args: Model inputs. + tenant_id: Tenant to use. If None, uses current context. + **kwargs: Additional arguments. + """ + if tenant_id is not None: + MegatronMultiAdapter.set_current_adapter_name(tenant_id) + + # Ensure DDP is initialized + if self._ddp is None: + self._wrap_with_ddp() + + return self._ddp(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def backward(self, loss: torch.Tensor, tenant_id: Optional[str] = None): + """ + Backward pass with optional tenant selection. + + Args: + loss: Loss tensor. + tenant_id: Tenant for gradient accumulation. + """ + if tenant_id is not None: + MegatronMultiAdapter.set_current_adapter_name(tenant_id) + + loss.backward() + + # Sync gradients for this tenant + self._reduce_tenant_gradients(tenant_id) + + def _reduce_tenant_gradients(self, tenant_id: Optional[str] = None): + """ + Reduce gradients for a specific tenant. + + For now, uses Megatron DDP's finish_grad_sync which syncs all params. + A more optimized version could filter to only tenant's params. + """ + tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() + + if self._ddp is not None: + self._ddp.finish_grad_sync() + + def step(self, tenant_id: Optional[str] = None): + """ + Optimizer step for a tenant. + + Args: + tenant_id: Tenant to update. If None, uses current context. 
+ """ + tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() + + if tenant_id is None: + raise ValueError("No tenant specified and no current tenant context") + + state = self._tenants.get(tenant_id) + if state is None: + raise ValueError(f"Tenant '{tenant_id}' not registered") + + if state.optimizer is not None: + state.optimizer.step() + + def zero_grad(self, tenant_id: Optional[str] = None): + """Zero gradients for a tenant.""" + tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() + + if tenant_id is None: + # Zero all + if self._ddp is not None: + self._ddp.zero_grad_buffer() + return + + state = self._tenants.get(tenant_id) + if state is not None and state.optimizer is not None: + state.optimizer.zero_grad() + + def lr_step(self, tenant_id: Optional[str] = None): + """LR scheduler step for a tenant.""" + tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() + + if tenant_id is None: + return + + state = self._tenants.get(tenant_id) + if state is not None and state.scheduler is not None: + state.scheduler.step() + + def clip_grad_norm( + self, + max_norm: float = 1.0, + norm_type: float = 2.0, + tenant_id: Optional[str] = None, + ) -> torch.Tensor: + """Clip gradients for a tenant.""" + tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() + + if tenant_id is None: + raise ValueError("No tenant specified") + + state = self._tenants.get(tenant_id) + if state is None: + raise ValueError(f"Tenant '{tenant_id}' not registered") + + return torch.nn.utils.clip_grad_norm_( + state.trainable_params, max_norm, norm_type + ) + + def get_tenant_state(self, tenant_id: str) -> Optional[TenantState]: + """Get state for a tenant.""" + return self._tenants.get(tenant_id) + + def list_tenants(self) -> List[str]: + """List all registered tenants.""" + return list(self._tenants.keys()) + + @property + def ddp(self) -> Optional[MegatronDDP]: + """Get the DDP wrapper.""" + return self._ddp + + @property 
+ def unwrapped_model(self) -> nn.Module: + """Get the unwrapped model.""" + return self.model diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 1cb70474..e4cf71d2 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -910,11 +910,7 @@ def _is_model_ddp_wrapped(self) -> bool: True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP). """ from torch.nn.parallel import DistributedDataParallel as TorchDDP - try: - from twinkle.megatron.distributed import LoRADistributedDataParallel - return isinstance(self.model, (MegatronDDP, LoRADistributedDataParallel, TorchDDP)) - except ImportError: - return isinstance(self.model, (MegatronDDP, TorchDDP)) + return isinstance(self.model, (MegatronDDP, TorchDDP)) def _get_unwrapped_model(self) -> nn.Module: """Get the unwrapped model. @@ -1029,13 +1025,6 @@ def finalize_model_grads_for_ddp(model_list, *args, **kwargs): model_chunk.finish_grad_sync() transformer_config.finalize_model_grads_func = finalize_model_grads_for_ddp - if mpu.get_data_parallel_rank() == 0: - lora_count = self.model.get_lora_param_count() - lora_numel = self.model.get_lora_param_numel() - print(f"Wrapped model with LoRA DDP: {lora_count} params, {lora_numel:,} elements") - print(f" overlap_grad_reduce={overlap_grad_reduce}") - print(f" bucket_size={bucket_size or 'auto'}") - return self @remote_function(dispatch='all') diff --git a/src/twinkle/model/strategy/megatron.py b/src/twinkle/model/strategy/megatron.py index 748b57ef..1b721083 100644 --- a/src/twinkle/model/strategy/megatron.py +++ b/src/twinkle/model/strategy/megatron.py @@ -509,6 +509,7 @@ def _wrap_with_megatron_ddp( ) # Wrap with MegatronDDP + # TODO: multi-tenant ddp try: wrapped_model = MegatronDDP( config=config, From fef94242f9ce2aef1e50ca105560ec858b58202f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 16 Jan 2026 19:42:34 +0800 Subject: [PATCH 08/22] tenant --- cookbook/megatron/__init__.py | 1 - 
cookbook/megatron/lora.py | 78 +- .../megatron/megatron_multi_tenant/client.py | 162 +++ .../megatron/megatron_multi_tenant/server.py | 237 ++++ cookbook/megatron/moe_lora.py | 108 +- cookbook/megatron_multi_tenant/server.py | 239 ++++ src/twinkle/megatron/__init__.py | 64 +- src/twinkle/megatron/distributed/__init__.py | 33 +- .../megatron/distributed/multi_tenant_ddp.py | 628 +++++----- .../megatron/distributed/tenant_context.py | 106 ++ .../megatron/distributed/tenant_manager.py | 268 ++++ src/twinkle/megatron/model/__init__.py | 31 +- src/twinkle/megatron/model/bridge.py | 1095 ++++++++++------- src/twinkle/megatron/model/initializer.py | 141 ++- .../megatron/model/multi_tenant_megatron.py | 652 ++++------ src/twinkle/megatron/model/qwen3.py | 20 +- src/twinkle/megatron/tuners/lora.py | 323 ++--- src/twinkle/megatron/utils.py | 4 - src/twinkle/model/megatron.py | 898 +++++++------- src/twinkle/model/strategy/megatron.py | 268 ++-- tests/megatron/test_multi_tenant_ddp.py | 181 +++ 21 files changed, 3420 insertions(+), 2117 deletions(-) create mode 100644 cookbook/megatron/megatron_multi_tenant/client.py create mode 100644 cookbook/megatron/megatron_multi_tenant/server.py create mode 100644 cookbook/megatron_multi_tenant/server.py create mode 100644 src/twinkle/megatron/distributed/tenant_context.py create mode 100644 src/twinkle/megatron/distributed/tenant_manager.py create mode 100644 tests/megatron/test_multi_tenant_ddp.py diff --git a/cookbook/megatron/__init__.py b/cookbook/megatron/__init__.py index a0b9f9e5..49006762 100644 --- a/cookbook/megatron/__init__.py +++ b/cookbook/megatron/__init__.py @@ -1,2 +1 @@ # Copyright (c) twinkle authors. All rights reserved. 
- diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index c87e5301..8af6556e 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -12,45 +12,56 @@ import argparse import os +import numpy as np +# CRITICAL: Set CUDA device before any CUDA imports (local mode only) +import torch +from peft import LoraConfig +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR + +import twinkle +from twinkle import (DeviceGroup, DeviceMesh, Platform, get_device_placement, + get_logger) +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import MegatronCrossEntropyLoss +from twinkle.model import MegatronModel +from twinkle.processor import InputProcessor + # Parse arguments first to determine mode parser = argparse.ArgumentParser() -parser.add_argument('--mode', type=str, default='local', choices=['local', 'ray']) +parser.add_argument('--mode', + type=str, + default='local', + choices=['local', 'ray']) parser.add_argument('--tp_size', type=int, default=1) parser.add_argument('--pp_size', type=int, default=1) parser.add_argument('--cp_size', type=int, default=1) -parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs (Ray mode only)') +parser.add_argument('--num_gpus', + type=int, + default=4, + help='Number of GPUs (Ray mode only)') parser.add_argument('--max_steps', type=int, default=None) -parser.add_argument('--model', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct') +parser.add_argument('--model', + type=str, + default='ms://Qwen/Qwen2.5-7B-Instruct') args = parser.parse_args() # Set mode in environment before importing twinkle os.environ['TWINKLE_MODE'] = args.mode -# CRITICAL: Set CUDA device before any CUDA imports (local mode only) -import torch if args.mode == 'local': LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) torch.cuda.set_device(LOCAL_RANK) -import numpy as np -from peft import LoraConfig -from torch.optim import AdamW 
-from torch.optim.lr_scheduler import LinearLR - -import twinkle -from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.loss import MegatronCrossEntropyLoss -from twinkle.model import MegatronModel -from twinkle.processor import InputProcessor - logger = get_logger() def create_dataset(): - dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) - dataset.set_template('Qwen3Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct') + dataset = Dataset( + dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + dataset.set_template('Qwen3Template', + model_id='ms://Qwen/Qwen2.5-7B-Instruct') dataset.map('CompetitionMathProcessor') dataset.encode(batched=True, load_from_cache_file=False) return dataset @@ -61,24 +72,24 @@ def train(): TP_SIZE = args.tp_size PP_SIZE = args.pp_size CP_SIZE = args.cp_size - + if args.mode == 'local': WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) else: WORLD_SIZE = args.num_gpus - + DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) - + # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost device_mesh = DeviceMesh( device_type='cuda', mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), mesh_dim_names=('pp', 'dp', 'cp', 'tp'), ) - + # Device group name - used as remote_group in Ray mode GROUP_NAME = 'model' - + device_group = [ DeviceGroup( name=GROUP_NAME, @@ -86,7 +97,7 @@ def train(): device_type=Platform.get_platform().device_prefix(), ) ] - + twinkle.initialize( mode=args.mode, nproc_per_node=WORLD_SIZE, @@ -94,10 +105,10 @@ def train(): global_device_mesh=device_mesh, lazy_collect=False, ) - + # Use smaller batch size for single GPU to avoid OOM batch_size = 2 if WORLD_SIZE == 1 else 8 - + # In Ray mode, pass remote_group and device_mesh if args.mode == 'ray': dataloader = DataLoader( @@ -129,9 +140,13 @@ def 
train(): lora_config = LoraConfig(target_modules='all-linear') adapter_name = 'lora' - model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=16) + model.add_adapter_to_model(adapter_name, + lora_config, + gradient_accumulation_steps=16) model.set_template('Qwen3Template', adapter_name=adapter_name) - model.set_processor(InputProcessor, padding_side='right', adapter_name=adapter_name) + model.set_processor(InputProcessor, + padding_side='right', + adapter_name=adapter_name) model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name) model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) @@ -140,7 +155,8 @@ def train(): logger.info(model.get_train_configs(adapter_name=adapter_name)) for step, batch in enumerate(dataloader): - output = model.forward_backward(inputs=batch, adapter_name=adapter_name) + output = model.forward_backward(inputs=batch, + adapter_name=adapter_name) if step % 16 == 0: logger.info(f'Step {step // 16}, loss: {output}') model.clip_grad_norm(1.0, adapter_name=adapter_name) diff --git a/cookbook/megatron/megatron_multi_tenant/client.py b/cookbook/megatron/megatron_multi_tenant/client.py new file mode 100644 index 00000000..9a729e5f --- /dev/null +++ b/cookbook/megatron/megatron_multi_tenant/client.py @@ -0,0 +1,162 @@ +""" +Multi-Tenant Megatron LoRA Training - Client Example. + +Simple training loop using remote multi-tenant server. +Inspired by tinker-cookbook's minimal training scripts. 
+""" + +import logging +import time +from dataclasses import dataclass +from typing import Any, Dict, Iterator, Optional + +import requests + +logger = logging.getLogger(__name__) + + +@dataclass +class Config: + """Training configuration.""" + server_url: str = "http://localhost:8080" + lora_rank: int = 8 + learning_rate: float = 1e-4 + batch_size: int = 8 + gradient_accumulation_steps: int = 4 + max_grad_norm: float = 1.0 + log_every: int = 10 + + +class TrainingClient: + """ + Simple client for multi-tenant LoRA training. + + Example: + >>> client = TrainingClient(server_url) + >>> client.initialize(lora_rank=8, learning_rate=1e-4) + >>> + >>> for batch in dataloader: + ... result = client.forward_backward(batch) + ... if client.should_step(): + ... client.step() + >>> + >>> client.finalize() + """ + + def __init__(self, server_url: str = "http://localhost:8080"): + self.server_url = server_url.rstrip('/') + self.tenant_id: Optional[str] = None + self._session = requests.Session() + self._accumulated = 0 + self._ga_steps = 1 + + def _post(self, endpoint: str, **kwargs) -> Dict: + """Make POST request.""" + headers = {"X-Tenant-ID": self.tenant_id} if self.tenant_id else {} + resp = self._session.post( + f"{self.server_url}{endpoint}", + headers=headers, + json=kwargs, + timeout=300, + ) + resp.raise_for_status() + return resp.json() + + def initialize( + self, + lora_rank: int = 8, + learning_rate: float = 1e-4, + gradient_accumulation_steps: int = 1, + **kwargs, + ) -> str: + """Initialize tenant on server.""" + result = self._post( + "/initialize", + lora_config={"r": lora_rank, "target_modules": "all-linear"}, + optimizer_kwargs={"lr": learning_rate}, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.tenant_id = result["tenant_id"] + self._ga_steps = gradient_accumulation_steps + logger.info(f"Initialized: {self.tenant_id}") + return self.tenant_id + + def finalize(self): + """Cleanup tenant.""" + if self.tenant_id: + 
self._post("/finalize") + logger.info(f"Finalized: {self.tenant_id}") + self.tenant_id = None + + def forward_backward(self, inputs: Any) -> Dict: + """Forward + backward pass.""" + result = self._post("/forward_backward", inputs=inputs) + self._accumulated += 1 + return result.get("data", {}) + + def should_step(self) -> bool: + """Check if optimizer step should happen.""" + return self._accumulated >= self._ga_steps + + def step(self): + """Optimizer step.""" + self._post("/finish_grad_sync") + self._post("/clip_grad_norm") + self._post("/step") + self._post("/zero_grad") + self._post("/lr_step") + self._accumulated = 0 + + def __enter__(self): + return self + + def __exit__(self, *args): + self.finalize() + + +def main(config: Config): + """Example training loop.""" + logging.basicConfig(level=logging.INFO) + + # Create client + client = TrainingClient(config.server_url) + + # Initialize + client.initialize( + lora_rank=config.lora_rank, + learning_rate=config.learning_rate, + gradient_accumulation_steps=config.gradient_accumulation_steps, + ) + + try: + # Training loop + for step in range(100): + start = time.time() + + # Create dummy batch (replace with your data loading) + batch = { + "input_ids": list(range(128)), + "attention_mask": [1] * 128, + "labels": list(range(128)), + } + + # Forward + backward + result = client.forward_backward(batch) + + # Optimizer step + if client.should_step(): + client.step() + + if step % config.log_every == 0: + elapsed = time.time() - start + logger.info(f"Step {step}, time: {elapsed:.2f}s") + + logger.info("Training complete!") + + finally: + client.finalize() + + +if __name__ == "__main__": + main(Config()) diff --git a/cookbook/megatron/megatron_multi_tenant/server.py b/cookbook/megatron/megatron_multi_tenant/server.py new file mode 100644 index 00000000..6cfa63e6 --- /dev/null +++ b/cookbook/megatron/megatron_multi_tenant/server.py @@ -0,0 +1,237 @@ +""" +Multi-Tenant Megatron LoRA Training - Server. 
"""
Multi-Tenant Megatron LoRA Training - Server.

Creates a shared base model and provides APIs for multi-tenant training.
"""

import argparse
import logging
import threading
import time
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel

logger = logging.getLogger(__name__)


# ============ Request/Response Models ============

class InitializeRequest(BaseModel):
    """Tenant creation payload: LoRA + optimizer settings."""
    lora_config: Optional[Dict[str, Any]] = None
    optimizer_cls: str = "AdamW"
    optimizer_kwargs: Optional[Dict[str, Any]] = None
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0


class InputsRequest(BaseModel):
    """Generic forward/backward payload."""
    inputs: Any


class TenantResponse(BaseModel):
    """Uniform response envelope for all tenant endpoints."""
    status: str = "ok"
    tenant_id: Optional[str] = None
    data: Optional[Any] = None


# ============ Server ============

class MultiTenantServer:
    """Server managing multi-tenant Megatron model."""

    TIMEOUT = 60 * 30  # 30 min heartbeat timeout

    def __init__(self, model_id: str, tp_size: int = 1):
        self.model_id = model_id
        self.tp_size = tp_size
        self.model = None  # set in setup()
        self._heartbeats: Dict[str, float] = {}
        self._lock = threading.Lock()

    def setup(self):
        """Load the base model, freeze it, and start the heartbeat monitor."""
        from twinkle.megatron.model import (
            MultiTenantMegatronModel,
            initialize_megatron_model,
        )

        logger.info(f"Loading model: {self.model_id}")
        base_model, config = initialize_megatron_model(
            model_id=self.model_id,
            tensor_parallel_size=self.tp_size,
        )

        # Freeze base model: only per-tenant LoRA params will train.
        for p in base_model.parameters():
            p.requires_grad = False

        self.model = MultiTenantMegatronModel(base_model, config)
        logger.info("Server ready")

        # Start heartbeat monitor
        threading.Thread(target=self._monitor, daemon=True).start()

    def _monitor(self):
        """Periodically drop tenants whose heartbeat has expired."""
        while True:
            time.sleep(60)
            now = time.time()
            # Snapshot expired tenants under the lock, but finalize them
            # OUTSIDE it: finalize() re-acquires this non-reentrant lock,
            # so finalizing while holding it would deadlock the monitor.
            with self._lock:
                expired = [
                    t for t, ts in self._heartbeats.items()
                    if now - ts > self.TIMEOUT
                ]
            for tid in expired:
                logger.warning(f"Tenant {tid} timed out")
                try:
                    self.finalize(tid)
                except Exception:
                    # Best-effort cleanup; log instead of silently swallowing
                    # (a bare `except: pass` would also hide SystemExit).
                    logger.exception(f"Failed to finalize tenant {tid}")

    def _heartbeat(self, tenant_id: str):
        """Record activity for a tenant (resets its timeout)."""
        with self._lock:
            self._heartbeats[tenant_id] = time.time()

    def initialize(self, request: InitializeRequest) -> str:
        """Create a tenant (LoRA adapter + optimizer) and return its id."""
        from peft import LoraConfig

        lora_config = None
        if request.lora_config:
            lora_config = LoraConfig(**request.lora_config)

        # Unknown optimizer names fall back to AdamW.
        opt_map = {"AdamW": torch.optim.AdamW, "Adam": torch.optim.Adam}
        opt_cls = opt_map.get(request.optimizer_cls, torch.optim.AdamW)

        tenant_id = self.model.initialize(
            lora_config=lora_config,
            optimizer_cls=opt_cls,
            optimizer_kwargs=request.optimizer_kwargs,
            gradient_accumulation_steps=request.gradient_accumulation_steps,
            max_grad_norm=request.max_grad_norm,
        )

        self._heartbeat(tenant_id)
        return tenant_id

    def finalize(self, tenant_id: str):
        """Release a tenant's resources and forget its heartbeat."""
        self.model.finalize(tenant_id)
        with self._lock:
            self._heartbeats.pop(tenant_id, None)

    def forward_backward(self, tenant_id: str, inputs: Any) -> Dict:
        """Run forward + backward under the tenant's adapter scope."""
        self._heartbeat(tenant_id)

        with self.model.scope(tenant_id):
            output = self.model(inputs)
            # Compute loss (simplified - real impl would depend on task)
            loss = output.mean() if isinstance(output, torch.Tensor) else torch.tensor(0.0)
            self.model.backward(loss)
            return {"loss": loss.item()}

    def finish_grad_sync(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.finish_grad_sync(tenant_id)

    def clip_grad_norm(self, tenant_id: str) -> float:
        self._heartbeat(tenant_id)
        return self.model.clip_grad_norm(tenant_id=tenant_id).item()

    def step(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.step(tenant_id)

    def zero_grad(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.zero_grad(tenant_id)

    def lr_step(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.lr_step(tenant_id)

    def list_tenants(self) -> List[str]:
        return self.model.list_tenants()


# ============ FastAPI App ============

def create_app(server: MultiTenantServer) -> FastAPI:
    """Create FastAPI app exposing the multi-tenant training endpoints."""
    app = FastAPI(title="Multi-Tenant Megatron Server")

    def get_tenant(request: Request) -> str:
        # Tenant identity travels in a header so the JSON body stays task-only.
        tid = request.headers.get("X-Tenant-ID")
        if not tid:
            raise HTTPException(400, "Missing X-Tenant-ID")
        return tid

    @app.post("/initialize", response_model=TenantResponse)
    def initialize(body: InitializeRequest):
        tid = server.initialize(body)
        return TenantResponse(tenant_id=tid)

    @app.post("/finalize", response_model=TenantResponse)
    def finalize(request: Request):
        server.finalize(get_tenant(request))
        return TenantResponse()

    @app.post("/forward_backward", response_model=TenantResponse)
    def forward_backward(request: Request, body: InputsRequest):
        data = server.forward_backward(get_tenant(request), body.inputs)
        return TenantResponse(data=data)

    @app.post("/finish_grad_sync", response_model=TenantResponse)
    def finish_grad_sync(request: Request):
        server.finish_grad_sync(get_tenant(request))
        return TenantResponse()

    @app.post("/clip_grad_norm", response_model=TenantResponse)
    def clip_grad_norm(request: Request):
        norm = server.clip_grad_norm(get_tenant(request))
        return TenantResponse(data=norm)

    @app.post("/step", response_model=TenantResponse)
    def step(request: Request):
        server.step(get_tenant(request))
        return TenantResponse()

    @app.post("/zero_grad", response_model=TenantResponse)
    def zero_grad(request: Request):
        server.zero_grad(get_tenant(request))
        return TenantResponse()

    @app.post("/lr_step", response_model=TenantResponse)
    def lr_step(request: Request):
        server.lr_step(get_tenant(request))
        return TenantResponse()

    @app.get("/tenants")
    def tenants():
        return {"tenants": server.list_tenants()}

    @app.get("/health")
    def health():
        return {"status": "healthy"}

    return app


def main():
    """CLI entry point: load the model then serve HTTP via uvicorn."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", required=True)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    server = MultiTenantServer(args.model_id, args.tp)
    server.setup()

    import uvicorn
    uvicorn.run(create_app(server), host=args.host, port=args.port)


if __name__ == "__main__":
    main()
+parser.add_argument('--num_gpus', + type=int, + default=4, + help='Number of GPUs (Ray mode only)') parser.add_argument('--max_steps', type=int, default=5) -parser.add_argument('--model', type=str, default='ms://Qwen/Qwen3-30B-A3B', - help='MoE model path. Default: Qwen3-30B-A3B (128 experts)') -parser.add_argument('--sequence_parallel', action='store_true', default=False, - help='Enable sequence parallel (auto-enabled for MoE with TP > 1)') +parser.add_argument( + '--model', + type=str, + default='ms://Qwen/Qwen3-30B-A3B', + help='MoE model path. Default: Qwen3-30B-A3B (128 experts)') +parser.add_argument( + '--sequence_parallel', + action='store_true', + default=False, + help='Enable sequence parallel (auto-enabled for MoE with TP > 1)') args = parser.parse_args() # Set mode in environment before importing twinkle os.environ['TWINKLE_MODE'] = args.mode -# CRITICAL: Set CUDA device before any CUDA imports (local mode only) -import torch if args.mode == 'local': LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) torch.cuda.set_device(LOCAL_RANK) -import numpy as np -from peft import LoraConfig -from torch.optim import AdamW -from torch.optim.lr_scheduler import LinearLR - -import twinkle -from twinkle import get_device_placement, get_logger, DeviceMesh, DeviceGroup, Platform -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.loss import MegatronCrossEntropyLoss -from twinkle.model import MegatronModel -from twinkle.processor import InputProcessor - logger = get_logger() def create_dataset(): """Create dataset for MoE training.""" - dataset = Dataset(dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + dataset = Dataset( + dataset_meta=DatasetMeta('ms://modelscope/competition_math')) # Use Qwen3 template for MoE model dataset.set_template('Qwen3Template', model_id=args.model) dataset.map('CompetitionMathProcessor') @@ -68,36 +85,38 @@ def train(): PP_SIZE = args.pp_size CP_SIZE = args.cp_size 
EP_SIZE = args.ep_size - + if args.mode == 'local': WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) else: WORLD_SIZE = args.num_gpus - + # For MoE with EP: Total parallelism = TP * PP * CP * EP * DP # EP is placed between CP and DP in Megatron's order DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE) - + if DP_SIZE < 1: raise ValueError( - f"Not enough GPUs ({WORLD_SIZE}) for parallelism config: " - f"TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}. " - f"Required: {TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE}" - ) - - logger.info(f"Parallelism config: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}, DP={DP_SIZE}") - + f'Not enough GPUs ({WORLD_SIZE}) for parallelism config: ' + f'TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}. ' + f'Required: {TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE}') + + logger.info( + f'Parallelism config: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}, DP={DP_SIZE}' + ) + # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost # Shape: (PP, DP, EP, CP, TP) device_mesh = DeviceMesh( device_type='cuda', - mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, EP_SIZE, CP_SIZE, TP_SIZE), + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, EP_SIZE, CP_SIZE, + TP_SIZE), mesh_dim_names=('pp', 'dp', 'ep', 'cp', 'tp'), ) - + # Device group name - used as remote_group in Ray mode GROUP_NAME = 'model' - + device_group = [ DeviceGroup( name=GROUP_NAME, @@ -105,7 +124,7 @@ def train(): device_type=Platform.get_platform().device_prefix(), ) ] - + twinkle.initialize( mode=args.mode, nproc_per_node=WORLD_SIZE, @@ -113,10 +132,10 @@ def train(): global_device_mesh=device_mesh, lazy_collect=False, ) - + # Smaller batch size for MoE models (larger memory footprint) batch_size = 2 - + # In Ray mode, pass remote_group and device_mesh if args.mode == 'ray': dataloader = DataLoader( @@ -158,9 +177,13 @@ def train(): lora_dropout=0.0, ) adapter_name = 'lora' - 
model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=16) + model.add_adapter_to_model(adapter_name, + lora_config, + gradient_accumulation_steps=16) model.set_template('Qwen3Template', adapter_name=adapter_name) - model.set_processor(InputProcessor, padding_side='right', adapter_name=adapter_name) + model.set_processor(InputProcessor, + padding_side='right', + adapter_name=adapter_name) model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name) model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) @@ -169,7 +192,8 @@ def train(): logger.info(model.get_train_configs(adapter_name=adapter_name)) for step, batch in enumerate(dataloader): - output = model.forward_backward(inputs=batch, adapter_name=adapter_name) + output = model.forward_backward(inputs=batch, + adapter_name=adapter_name) if step % 16 == 0: logger.info(f'Step {step // 16}, loss: {output}') model.clip_grad_norm(1.0, adapter_name=adapter_name) diff --git a/cookbook/megatron_multi_tenant/server.py b/cookbook/megatron_multi_tenant/server.py new file mode 100644 index 00000000..45ecd925 --- /dev/null +++ b/cookbook/megatron_multi_tenant/server.py @@ -0,0 +1,239 @@ +""" +Multi-Tenant Megatron LoRA Training - Server. + +Creates a shared base model and provides APIs for multi-tenant training. 
import argparse
import logging
import threading
import time
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel

logger = logging.getLogger(__name__)


# ============ Request/Response Models ============

class InitializeRequest(BaseModel):
    """Parameters for creating a new tenant (LoRA adapter + optimizer)."""
    lora_config: Optional[Dict[str, Any]] = None
    optimizer_cls: str = "AdamW"
    optimizer_kwargs: Optional[Dict[str, Any]] = None
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0

class InputsRequest(BaseModel):
    """Opaque training inputs forwarded to the model."""
    inputs: Any

class TenantResponse(BaseModel):
    """Uniform response envelope for all tenant endpoints."""
    status: str = "ok"
    tenant_id: Optional[str] = None
    data: Optional[Any] = None


# ============ Server ============

class MultiTenantServer:
    """Server managing a shared, frozen Megatron base model with per-tenant
    LoRA adapters and optimizers.

    Tenants that stop sending requests are garbage-collected by a background
    heartbeat monitor after ``TIMEOUT`` seconds of inactivity.
    """

    TIMEOUT = 60 * 30  # 30 min heartbeat timeout

    def __init__(self, model_id: str, tp_size: int = 1):
        self.model_id = model_id
        self.tp_size = tp_size
        self.model = None  # set by setup()
        self._heartbeats: Dict[str, float] = {}  # tenant_id -> last-seen time
        self._lock = threading.Lock()  # guards _heartbeats

    def setup(self):
        """Initialize model."""
        from twinkle.megatron.model import (
            MultiTenantMegatronModel,
            initialize_megatron_model,
        )

        logger.info(f"Loading model: {self.model_id}")
        base_model, config = initialize_megatron_model(
            model_id=self.model_id,
            tensor_parallel_size=self.tp_size,
        )

        # Freeze base model: only per-tenant adapter params ever train.
        for p in base_model.parameters():
            p.requires_grad = False

        self.model = MultiTenantMegatronModel(base_model, config)
        logger.info("Server ready")

        # Start heartbeat monitor
        threading.Thread(target=self._monitor, daemon=True).start()

    def _monitor(self):
        """Cleanup inactive tenants (runs forever in a daemon thread)."""
        while True:
            time.sleep(60)
            now = time.time()
            with self._lock:
                expired = [t for t, ts in self._heartbeats.items() if now - ts > self.TIMEOUT]
            for tid in expired:
                logger.warning(f"Tenant {tid} timed out")
                try:
                    self.finalize(tid)
                except Exception:
                    # Best-effort cleanup: one tenant's failure must not kill
                    # the monitor thread, but swallowing it silently (the old
                    # bare `except: pass`) hid real errors — log it instead.
                    logger.exception(f"Failed to finalize tenant {tid}")

    def _heartbeat(self, tenant_id: str):
        # Record the tenant as alive; called on every tenant-scoped request.
        with self._lock:
            self._heartbeats[tenant_id] = time.time()

    def initialize(self, request: InitializeRequest) -> str:
        """Initialize tenant and return its freshly generated tenant id."""
        from peft import LoraConfig

        lora_config = None
        if request.lora_config:
            lora_config = LoraConfig(**request.lora_config)

        # Unknown optimizer names deliberately fall back to AdamW.
        opt_map = {"AdamW": torch.optim.AdamW, "Adam": torch.optim.Adam}
        opt_cls = opt_map.get(request.optimizer_cls, torch.optim.AdamW)

        tenant_id = self.model.initialize(
            lora_config=lora_config,
            optimizer_cls=opt_cls,
            optimizer_kwargs=request.optimizer_kwargs,
            gradient_accumulation_steps=request.gradient_accumulation_steps,
            max_grad_norm=request.max_grad_norm,
        )

        self._heartbeat(tenant_id)
        return tenant_id

    def finalize(self, tenant_id: str):
        """Finalize tenant and drop its heartbeat record."""
        self.model.finalize(tenant_id)
        with self._lock:
            self._heartbeats.pop(tenant_id, None)

    def forward_backward(self, tenant_id: str, inputs: Any) -> Dict:
        """Forward + backward under the tenant's adapter scope."""
        self._heartbeat(tenant_id)

        with self.model.scope(tenant_id):
            output = self.model(inputs)
            # Compute loss (simplified - real impl would depend on task)
            loss = output.mean() if isinstance(output, torch.Tensor) else torch.tensor(0.0)
            self.model.backward(loss)
            return {"loss": loss.item()}

    def finish_grad_sync(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.finish_grad_sync(tenant_id)

    def clip_grad_norm(self, tenant_id: str) -> float:
        self._heartbeat(tenant_id)
        return self.model.clip_grad_norm(tenant_id=tenant_id).item()

    def step(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.step(tenant_id)

    def zero_grad(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.zero_grad(tenant_id)

    def lr_step(self, tenant_id: str):
        self._heartbeat(tenant_id)
        self.model.lr_step(tenant_id)

    def tenant_count(self) -> int:
        """Get number of active tenants (does not expose tenant IDs)."""
        return self.model.tenant_count()


# ============ FastAPI App ============

def create_app(server: MultiTenantServer) -> FastAPI:
    """Create FastAPI app.

    Tenant-scoped endpoints read the tenant id from the ``X-Tenant-ID``
    header; ``/initialize`` is the only endpoint that mints one.
    """
    app = FastAPI(title="Multi-Tenant Megatron Server")

    def get_tenant(request: Request) -> str:
        tid = request.headers.get("X-Tenant-ID")
        if not tid:
            raise HTTPException(400, "Missing X-Tenant-ID")
        return tid

    @app.post("/initialize", response_model=TenantResponse)
    def initialize(body: InitializeRequest):
        tid = server.initialize(body)
        return TenantResponse(tenant_id=tid)

    @app.post("/finalize", response_model=TenantResponse)
    def finalize(request: Request):
        server.finalize(get_tenant(request))
        return TenantResponse()

    @app.post("/forward_backward", response_model=TenantResponse)
    def forward_backward(request: Request, body: InputsRequest):
        data = server.forward_backward(get_tenant(request), body.inputs)
        return TenantResponse(data=data)

    @app.post("/finish_grad_sync", response_model=TenantResponse)
    def finish_grad_sync(request: Request):
        server.finish_grad_sync(get_tenant(request))
        return TenantResponse()

    @app.post("/clip_grad_norm", response_model=TenantResponse)
    def clip_grad_norm(request: Request):
        norm = server.clip_grad_norm(get_tenant(request))
        return TenantResponse(data=norm)

    @app.post("/step", response_model=TenantResponse)
    def step(request: Request):
        server.step(get_tenant(request))
        return TenantResponse()

    @app.post("/zero_grad", response_model=TenantResponse)
    def zero_grad(request: Request):
        server.zero_grad(get_tenant(request))
        return TenantResponse()

    @app.post("/lr_step", response_model=TenantResponse)
    def lr_step(request: Request):
        server.lr_step(get_tenant(request))
        return TenantResponse()

    @app.get("/stats")
    def stats():
        """Server statistics (does not expose tenant IDs for privacy)."""
        return {"tenant_count": server.tenant_count()}

    @app.get("/health")
    def health():
        return {"status": "healthy"}

    return app


def main():
    """CLI entry point: parse args, load the model, then serve HTTP."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", required=True)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    server = MultiTenantServer(args.model_id, args.tp)
    server.setup()

    import uvicorn
    uvicorn.run(create_app(server), host=args.host, port=args.port)


if __name__ == "__main__":
    main()
- deep_getattr as bridge_deep_getattr, # Avoid conflict with utils.deep_getattr - # Legacy compatibility - create_megatron_args, - set_megatron_args, - restore_megatron_args, - mock_megatron_args, - # Initializer - MegatronModelInitializer, - initialize_megatron_model, - # Qwen3 support - Qwen3ModelMeta, - get_model_default_config, -) + +from .model import deep_getattr as bridge_deep_getattr # Bridge classes; Helper functions; Avoid conflict with utils.deep_getattr; Legacy compatibility; Initializer; Qwen3 support +from .utils import ( # Layer finding; Model preparation; Config conversion; Utilities; Multi-tenant support; Training state + MegatronTrainerState, TenantProcessGroupManager, convert_hf_config, + deep_getattr, find_all_linears, find_embedding, find_router, + forward_step_helper, get_model_parameter_info, get_padding_to, + get_target_modules, get_tenant_manager, patch_deepcopy, prepare_lora_model, + prepare_mcore_model, set_linear_is_expert, tuners_sharded_state_dict) __all__ = [ # Tuners diff --git a/src/twinkle/megatron/distributed/__init__.py b/src/twinkle/megatron/distributed/__init__.py index 76a7c8a0..a1defbab 100644 --- a/src/twinkle/megatron/distributed/__init__.py +++ b/src/twinkle/megatron/distributed/__init__.py @@ -1,16 +1,31 @@ # Copyright (c) twinkle authors. All rights reserved. +""" +Distributed training utilities for multi-tenant Megatron LoRA. 
+Core components: +- tenant_context: ContextVar-based tenant management +- tenant_manager: Tenant lifecycle (adapters, optimizers) +- multi_tenant_ddp: Per-tenant gradient buffers and sync +""" -from .multi_tenant_ddp import ( - MultiTenantLoRADDP, - TenantContext, - TenantGradientManager, - create_multi_tenant_ddp, -) +from .multi_tenant_ddp import MultiTenantLoRADDP, TenantDDPState +from .tenant_context import (TenantInfo, generate_tenant_id, + get_current_tenant, require_tenant, + set_current_tenant, tenant_scope) +from .tenant_manager import TenantManager, TenantState __all__ = [ + # Context + 'get_current_tenant', + 'set_current_tenant', + 'require_tenant', + 'tenant_scope', + 'generate_tenant_id', + 'TenantInfo', + # Manager + 'TenantManager', + 'TenantState', + # DDP 'MultiTenantLoRADDP', - 'TenantContext', - 'TenantGradientManager', - 'create_multi_tenant_ddp', + 'TenantDDPState', ] diff --git a/src/twinkle/megatron/distributed/multi_tenant_ddp.py b/src/twinkle/megatron/distributed/multi_tenant_ddp.py index 9f84f230..e046555a 100644 --- a/src/twinkle/megatron/distributed/multi_tenant_ddp.py +++ b/src/twinkle/megatron/distributed/multi_tenant_ddp.py @@ -2,341 +2,397 @@ """ Multi-Tenant LoRA DDP for Megatron models. -This module provides a minimal, maintainable DDP solution for multi-tenant LoRA training: -1. Inherits from Megatron's DistributedDataParallel to maximize code reuse -2. Uses MultiAdapter's ContextVar mechanism for tenant isolation -3. Supports per-tenant process groups for gradient synchronization +This module provides a DDP implementation for multi-tenant LoRA training, +inheriting from Megatron's DistributedDataParallel. -Key insight: Megatron DDP already only creates buffers for requires_grad=True params, -so we just need to control which params are trainable per-tenant. +Key Design: +1. Inherits from MegatronDDP for code reuse +2. Overrides buffer/bucket creation to be per-tenant +3. Uses ContextVar for automatic tenant resolution +4. 
Tenant lifecycle managed by TenantManager (separate concern) """ -import contextvars import logging -from typing import Dict, List, Optional, Set +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Dict, List, Optional +import torch import torch.distributed as dist import torch.nn as nn +from .tenant_context import get_current_tenant, require_tenant, tenant_scope + logger = logging.getLogger(__name__) try: from megatron.core import parallel_state as mpu from megatron.core.distributed import DistributedDataParallel as MegatronDDP from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig + from megatron.core.distributed.param_and_grad_buffer import ( + _ParamAndGradBuffer, + partition_buckets, + ) + from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.transformer_config import TransformerConfig MEGATRON_AVAILABLE = True except ImportError: MEGATRON_AVAILABLE = False - MegatronDDP = object + # Fallback for type hints + class MegatronDDP(nn.Module): + pass -class TenantContext: - """ - Thread/coroutine-safe tenant context using ContextVar. - - This integrates with MultiAdapter's ContextVar mechanism to ensure - that each request/coroutine operates on the correct tenant's LoRA weights. - """ - - _current_tenant: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( - 'current_tenant', default=None - ) - - @classmethod - def get_current_tenant(cls) -> Optional[str]: - return cls._current_tenant.get() - - @classmethod - def set_current_tenant(cls, tenant_id: str): - cls._current_tenant.set(tenant_id) - - @classmethod - def reset_tenant(cls): - cls._current_tenant.set(None) - - -class TenantGradientManager: - """ - Manages per-tenant gradient buffers and communication groups. - - This is a lightweight wrapper that doesn't duplicate Megatron DDP logic, - but instead coordinates gradient sync across tenants. 
- """ - - def __init__(self): - self.tenant_params: Dict[str, Set[nn.Parameter]] = {} - self.tenant_process_groups: Dict[str, dist.ProcessGroup] = {} - self.tenant_param_names: Dict[str, Dict[nn.Parameter, str]] = {} - - def register_tenant( - self, - tenant_id: str, - params: List[nn.Parameter], - param_names: Dict[nn.Parameter, str], - process_group: Optional[dist.ProcessGroup] = None, - ): - """ - Register a tenant with its LoRA parameters. - - Args: - tenant_id: Unique tenant identifier. - params: List of LoRA parameters for this tenant. - param_names: Mapping from param to name for debugging. - process_group: Optional custom process group for this tenant. - """ - self.tenant_params[tenant_id] = set(params) - self.tenant_param_names[tenant_id] = param_names - - if process_group is not None: - self.tenant_process_groups[tenant_id] = process_group - else: - # Use default DP group - self.tenant_process_groups[tenant_id] = mpu.get_data_parallel_group( - with_context_parallel=True - ) - - logger.info( - f"Registered tenant '{tenant_id}' with {len(params)} parameters, " - f"process group size: {self.tenant_process_groups[tenant_id].size()}" - ) - - def unregister_tenant(self, tenant_id: str): - """Remove a tenant and its associated resources.""" - self.tenant_params.pop(tenant_id, None) - self.tenant_param_names.pop(tenant_id, None) - self.tenant_process_groups.pop(tenant_id, None) - logger.info(f"Unregistered tenant '{tenant_id}'") - - def get_tenant_params(self, tenant_id: str) -> Set[nn.Parameter]: - return self.tenant_params.get(tenant_id, set()) - - def get_tenant_process_group(self, tenant_id: str) -> Optional[dist.ProcessGroup]: - return self.tenant_process_groups.get(tenant_id) - - -class MultiTenantLoRADDP(MegatronDDP if MEGATRON_AVAILABLE else object): + +@dataclass +class TenantDDPState: + """Per-tenant DDP state: buffers, bucket groups, hooks.""" + tenant_id: str + params: List[nn.Parameter] = field(default_factory=list) + buffers: List = 
field(default_factory=list) + bucket_groups: List = field(default_factory=list) + param_to_bucket_group: Dict[nn.Parameter, + object] = field(default_factory=dict) + grad_accs: List = field(default_factory=list) + process_group: Optional[dist.ProcessGroup] = None + + +class MultiTenantLoRADDP(MegatronDDP): """ - Multi-Tenant LoRA DDP wrapper that extends Megatron's DDP. - - Design principles: - 1. **Minimal override**: Only override what's necessary for multi-tenant support - 2. **Reuse Megatron DDP**: All gradient buffer management, bucketing, and - communication overlap logic is inherited from Megatron DDP - 3. **ContextVar integration**: Uses TenantContext for thread-safe tenant switching - 4. **Lazy buffer creation**: Buffers are created per-tenant on first use - - Key insight: Instead of creating a separate DDP per tenant, we: - - Keep one DDP instance with all LoRA parameters - - Use ContextVar to track current tenant - - Filter gradient sync to only current tenant's params - - Example: - >>> # Create multi-tenant DDP + Multi-Tenant LoRA DDP inheriting from MegatronDDP. + + This class extends MegatronDDP to support per-tenant gradient buffers + and communication. The key difference is that instead of creating + buffers for all parameters at init, buffers are created dynamically + for each tenant. 
+ + Comparison with MegatronDDP: + - MegatronDDP: Creates buffers for all requires_grad=True params at __init__ + - MultiTenantLoRADDP: Creates buffers per-tenant when add_tenant is called + + Usage: + >>> # Create with frozen base model (no trainable params yet) >>> ddp = MultiTenantLoRADDP(config, ddp_config, model) - >>> - >>> # Register tenants - >>> ddp.register_tenant('tenant_a', tenant_a_params) - >>> ddp.register_tenant('tenant_b', tenant_b_params) - >>> - >>> # Training loop with tenant isolation - >>> TenantContext.set_current_tenant('tenant_a') - >>> output = ddp(input) # Uses tenant_a's LoRA - >>> loss.backward() - >>> ddp.finish_grad_sync() # Only syncs tenant_a's grads + >>> + >>> # Add tenant (creates buffers for their LoRA params) + >>> ddp.add_tenant('tenant_a', params_a, process_group_a) + >>> + >>> # Training uses current tenant context + >>> with tenant_scope('tenant_a'): + ... ddp.zero_grad_buffer() # Zeros tenant_a's buffers + ... output = ddp(input) + ... loss.backward() + ... ddp.finish_grad_sync() # Syncs tenant_a's gradients + >>> + >>> # Remove tenant + >>> ddp.remove_tenant('tenant_a') """ - - # Default patterns to identify LoRA parameters - DEFAULT_LORA_PATTERNS = {'lora_A', 'lora_B', 'lora_'} - def __init__( self, config: 'TransformerConfig', ddp_config: 'DistributedDataParallelConfig', module: nn.Module, disable_bucketing: bool = False, - lora_param_patterns: Optional[Set[str]] = None, - **kwargs, + pg_collection: Optional['ProcessGroupCollection'] = None, ): """ Initialize MultiTenantLoRADDP. - - This calls the parent Megatron DDP __init__ which will: - 1. Create gradient buffers for all requires_grad=True params - 2. Set up backward hooks - 3. Initialize bucket groups - - We then add multi-tenant management on top. + + Unlike MegatronDDP, this does NOT create buffers at init. + Buffers are created per-tenant via add_tenant(). + + Args: + config: Transformer config. + ddp_config: DDP config. 
+ module: Model (base model should be frozen). + disable_bucketing: Disable bucketing. + pg_collection: Process group collection. """ if not MEGATRON_AVAILABLE: - raise ImportError("Megatron-Core is required") - - self.lora_param_patterns = lora_param_patterns or self.DEFAULT_LORA_PATTERNS - self._tenant_manager = TenantGradientManager() - - # Pre-identify LoRA parameters before parent init - # This helps with debugging and tenant registration - self._lora_params: Dict[str, nn.Parameter] = {} + raise ImportError('Megatron-Core is required') + + # Skip MegatronDDP's buffer creation by temporarily setting all params to not require grad + original_requires_grad = {} for name, param in module.named_parameters(): - if param.requires_grad and self._is_lora_param(name): - self._lora_params[name] = param - - logger.info(f"Identified {len(self._lora_params)} LoRA parameters") - - # Call parent Megatron DDP init - # This creates buffers for all requires_grad=True params + original_requires_grad[name] = param.requires_grad + param.requires_grad = False + + # Call parent init (will create empty buffers since no params require grad) super().__init__( config=config, ddp_config=ddp_config, module=module, disable_bucketing=disable_bucketing, - **kwargs, - ) - - logger.info( - f"MultiTenantLoRADDP initialized with {len(self.params_with_grad)} " - f"trainable parameters, {len(self.bucket_groups)} bucket groups" + pg_collection=pg_collection, ) - - def _is_lora_param(self, name: str) -> bool: - """Check if parameter name matches LoRA patterns.""" - for pattern in self.lora_param_patterns: - if pattern in name: - return True - return False - - def register_tenant( + + # Restore requires_grad + for name, param in module.named_parameters(): + param.requires_grad = original_requires_grad[name] + + # Per-tenant state + self._tenant_states: Dict[str, TenantDDPState] = {} + + logger.info('MultiTenantLoRADDP initialized (no buffers yet)') + + def add_tenant( self, tenant_id: str, - 
adapter_name: Optional[str] = None, + params: List[nn.Parameter], process_group: Optional[dist.ProcessGroup] = None, + param_names: Optional[Dict[nn.Parameter, str]] = None, ): """ - Register a tenant for multi-tenant training. - + Add a tenant with their gradient buffers. + + This creates per-tenant buffers and hooks, similar to what + MegatronDDP.__init__ does but scoped to this tenant. + Args: - tenant_id: Unique tenant identifier. - adapter_name: PEFT adapter name (if different from tenant_id). - process_group: Custom process group for gradient sync. + tenant_id: Unique tenant ID. + params: Trainable parameters for this tenant. + process_group: Process group for gradient sync. + param_names: Param to name mapping for debugging. """ - adapter_name = adapter_name or tenant_id - - # Find parameters belonging to this adapter - tenant_params = [] - param_names = {} - - for name, param in self._lora_params.items(): - # Match adapter name in parameter path - # e.g., "model.layers.0.self_attn.q_proj.lora_A.tenant_a.weight" - if f'.{adapter_name}.' 
in name or name.endswith(f'.{adapter_name}'): - tenant_params.append(param) - param_names[param] = name - - if not tenant_params: - # If no adapter-specific match, assume all LoRA params belong to this tenant - # This handles single-tenant scenarios - logger.warning( - f"No adapter-specific params found for '{adapter_name}', " - f"registering all {len(self._lora_params)} LoRA params" - ) - tenant_params = list(self._lora_params.values()) - param_names = {v: k for k, v in self._lora_params.items()} - - self._tenant_manager.register_tenant( + if tenant_id in self._tenant_states: + raise ValueError(f"Tenant '{tenant_id}' already exists") + + if not params: + raise ValueError('No parameters provided') + + process_group = process_group or self.intra_dp_cp_group + param_names = param_names or {} + + # Build param_names if not provided + if not param_names: + for name, param in self.module.named_parameters(): + if param in params: + param_names[param] = name + + # Create tenant state + state = TenantDDPState( tenant_id=tenant_id, - params=tenant_params, - param_names=param_names, + params=params, process_group=process_group, ) - - def unregister_tenant(self, tenant_id: str): - """Remove a tenant.""" - self._tenant_manager.unregister_tenant(tenant_id) - - def set_current_tenant(self, tenant_id: str): - """ - Set the current tenant for subsequent operations. - - This should be called before forward/backward to ensure - correct LoRA adapter is used. - """ - TenantContext.set_current_tenant(tenant_id) - - def get_current_tenant(self) -> Optional[str]: - """Get the current tenant ID.""" - return TenantContext.get_current_tenant() - - def finish_grad_sync_for_tenant(self, tenant_id: Optional[str] = None): - """ - Finish gradient sync for a specific tenant. - - If tenant_id is None, uses current tenant from context. - If no tenant is set, falls back to syncing all params (parent behavior). 
- """ - tenant_id = tenant_id or TenantContext.get_current_tenant() - - if tenant_id is None: - # No tenant specified, use default behavior - super().finish_grad_sync() - return - - # Get tenant's process group - pg = self._tenant_manager.get_tenant_process_group(tenant_id) - if pg is None: - logger.warning(f"Tenant '{tenant_id}' not registered, using default sync") - super().finish_grad_sync() - return - - # For now, use parent's finish_grad_sync - # In a more advanced implementation, we could filter to only - # sync the tenant's parameters, but Megatron's bucket design - # makes this complex - super().finish_grad_sync() - - def get_tenant_param_count(self, tenant_id: str) -> int: - """Get number of parameters for a tenant.""" - return len(self._tenant_manager.get_tenant_params(tenant_id)) - - def get_tenant_param_numel(self, tenant_id: str) -> int: - """Get total number of elements in tenant's parameters.""" - return sum(p.numel() for p in self._tenant_manager.get_tenant_params(tenant_id)) - - -def create_multi_tenant_ddp( - model: nn.Module, - config: 'TransformerConfig', - ddp_config: Optional['DistributedDataParallelConfig'] = None, - lora_param_patterns: Optional[Set[str]] = None, - overlap_grad_reduce: bool = True, - bucket_size: Optional[int] = None, -) -> MultiTenantLoRADDP: - """ - Factory function to create a MultiTenantLoRADDP wrapper. - - Args: - model: Model containing LoRA layers. - config: Transformer configuration. - ddp_config: DDP configuration. If None, creates default config. - lora_param_patterns: Patterns to identify LoRA parameters. - overlap_grad_reduce: Enable communication-computation overlap. - bucket_size: Size of gradient buckets. - - Returns: - MultiTenantLoRADDP wrapper. 
def _create_tenant_buffers(
    self,
    state: TenantDDPState,
    param_names: Dict[nn.Parameter, str],
):
    """Allocate the tenant's gradient buffers, grouped by (param, grad) dtype."""
    # Group parameters by their (param dtype, grad dtype) pair.  Gradients
    # are held in fp32 when the DDP config requests fp32 grad reduction.
    params_by_dtype = {}
    indices_by_dtype = {}
    for p in state.params:
        grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else p.dtype
        dtype_key = (p.dtype, grad_dtype)
        group_list = params_by_dtype.setdefault(dtype_key, [])
        index_list = indices_by_dtype.setdefault(dtype_key, [])
        group_list.append(p)
        # NOTE(review): these indices are local to the dtype group, not
        # global positions in state.params — confirm _ParamAndGradBuffer
        # expects group-local indices.
        index_list.append(len(group_list) - 1)

    # Gradient scaling: per-token loss and in-collective averaging both
    # leave gradients unscaled; otherwise divide by the DP group size.
    if self.config.calculate_per_token_loss:
        gradient_scaling_factor = 1.0
    elif self.ddp_config.average_in_collective:
        gradient_scaling_factor = 1.0
    else:
        gradient_scaling_factor = 1.0 / state.process_group.size()

    # Process groups handed to the buffer: TP comes from the parent module,
    # DP/CP is the tenant's own group.
    pg_collection = ProcessGroupCollection()
    pg_collection.tp = self.tp_group
    pg_collection.dp_cp = state.process_group

    for dtype_key, group_params in params_by_dtype.items():
        param_dtype, grad_dtype = dtype_key
        buffer = _ParamAndGradBuffer(
            self.ddp_config,
            param_dtype,
            grad_dtype,
            group_params,
            state.process_group,
            self.bucket_size,
            param_names,
            gradient_scaling_factor,
            indices_by_dtype[dtype_key],
            getattr(self.ddp_config, 'nccl_ub', False),
            pg_collection,
        )
        state.buffers.append(buffer)

    # A bucket_size of None forces a single bucket group.
    state.bucket_groups = partition_buckets(
        state.buffers,
        force_single_bucket_group=(self.bucket_size is None),
    )

    # Reverse index (parameter -> bucket group) used by the backward hooks.
    for bucket_group in state.bucket_groups:
        for bucket in bucket_group.buckets:
            for p in bucket.params_list:
                state.param_to_bucket_group[p] = bucket_group

def _register_tenant_hooks(self, state: TenantDDPState):
    """Attach grad-accumulation hooks to every bucketed tenant parameter."""
    for p in state.params:
        if p not in state.param_to_bucket_group:
            continue
        # expand_as(p) produces an autograd view whose grad_fn exposes the
        # AccumulateGrad node; a hook on that node fires once the gradient
        # for `p` has been accumulated.
        accum_node = p.expand_as(p).grad_fn.next_functions[0][0]
        accum_node.register_hook(self._make_tenant_backward_hook(p, state))
        # Keep a reference so the hook owner is not garbage-collected.
        state.grad_accs.append(accum_node)

def _make_tenant_backward_hook(self, param: nn.Parameter,
                               state: TenantDDPState):
    """Build the backward hook folding ``param.grad`` into ``main_grad``."""

    def hook(*unused):
        if param in state.param_to_bucket_group:
            # Fold the freshly computed grad into the persistent buffer
            # exactly once per backward, then drop the transient .grad.
            if param.grad is not None and not param.grad_added_to_main_grad:
                param.main_grad.add_(param.grad.data)
            param.grad = None

            if self.ddp_config.overlap_grad_reduce:
                bucket_group = state.param_to_bucket_group[param]
                # Only kick off communication on the final microbatch.
                if bucket_group.is_last_microbatch:
                    bucket_group.register_grad_ready(param)

    return hook

def remove_tenant(self, tenant_id: str):
    """Remove a tenant and release its hooks, buffers and param attributes."""
    if tenant_id not in self._tenant_states:
        raise KeyError(f"Tenant '{tenant_id}' not found")

    state = self._tenant_states.pop(tenant_id)

    state.grad_accs.clear()        # drop hook owners
    state.buffers.clear()          # release gradient buffers
    state.bucket_groups.clear()
    state.param_to_bucket_group.clear()

    # Strip the attributes the buffers attached to each parameter.
    for p in state.params:
        if hasattr(p, 'main_grad'):
            delattr(p, 'main_grad')
        if hasattr(p, 'grad_added_to_main_grad'):
            delattr(p, 'grad_added_to_main_grad')

    logger.info(f"Removed tenant '{tenant_id}'")

def _get_tenant_state(self,
                      tenant_id: Optional[str] = None) -> TenantDDPState:
    """Resolve a tenant's DDP state, defaulting to the context tenant."""
    tenant_id = tenant_id or require_tenant()
    if tenant_id not in self._tenant_states:
        raise KeyError(f"Tenant '{tenant_id}' not registered")
    return self._tenant_states[tenant_id]

# ========== Override MegatronDDP methods to be tenant-aware ==========

@contextmanager
def no_sync(self, tenant_id: Optional[str] = None):
    """Temporarily disable gradient sync for one tenant's bucket groups."""
    state = self._get_tenant_state(tenant_id)
    for bucket_group in state.bucket_groups:
        bucket_group.is_last_microbatch = False
    try:
        yield
    finally:
        for bucket_group in state.bucket_groups:
            bucket_group.is_last_microbatch = True

def start_grad_sync(self, tenant_id: Optional[str] = None):
    """Launch (async) gradient sync for a tenant's bucket groups."""
    for bucket_group in self._get_tenant_state(tenant_id).bucket_groups:
        bucket_group.start_grad_sync()

def finish_grad_sync(self, tenant_id: Optional[str] = None):
    """Block until a tenant's gradient sync has completed."""
    for bucket_group in self._get_tenant_state(tenant_id).bucket_groups:
        bucket_group.finish_grad_sync()

def zero_grad_buffer(self, tenant_id: Optional[str] = None):
    """Reset a tenant's gradient buffers and per-param accumulation flags."""
    state = self._get_tenant_state(tenant_id)

    for p in state.params:
        p.grad_added_to_main_grad = False
    for buffer in state.buffers:
        buffer.reset()
    for bucket_group in state.bucket_groups:
        bucket_group.reset()

def scale_gradients(self,
                    scaling_factor: float,
                    tenant_id: Optional[str] = None):
    """Scale all buffered gradients of a tenant by ``scaling_factor``."""
    for buffer in self._get_tenant_state(tenant_id).buffers:
        buffer.scale_gradients(scaling_factor)

def broadcast_params(self, tenant_id: Optional[str] = None):
    """Broadcast a tenant's parameters from rank 0 of its process group."""
    state = self._get_tenant_state(tenant_id)
    for p in state.params:
        dist.broadcast(
            p.data,
            src=dist.get_global_rank(state.process_group, 0),
            group=state.process_group,
        )

# ========== Utility ==========

def has_tenant(self, tenant_id: str) -> bool:
    """Return True if the tenant is registered."""
    return tenant_id in self._tenant_states

def list_tenants(self) -> List[str]:
    """Return all registered tenant ids."""
    return list(self._tenant_states.keys())

def get_tenant_params(self,
                      tenant_id: Optional[str] = None
                      ) -> List[nn.Parameter]:
    """Get parameters for a tenant (requires valid tenant context)."""
    return self._get_tenant_state(tenant_id).params

# NOTE(review): the original trailing comment claimed list_tenants() is not
# exposed "to prevent information leakage", yet the method is defined above.
# Confirm whether tenant enumeration should be server-side only.
# Process-level ContextVar holding the active tenant id; propagates through
# async tasks and threads automatically.
_current_tenant: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    'current_tenant', default=None
)


def get_current_tenant() -> Optional[str]:
    """Return the tenant id bound to the current context, or ``None``."""
    return _current_tenant.get()


def set_current_tenant(tenant_id: Optional[str]) -> contextvars.Token:
    """Bind ``tenant_id`` to the current context; return the reset token."""
    return _current_tenant.set(tenant_id)


def require_tenant() -> str:
    """Return the current tenant id, raising ``RuntimeError`` if unset."""
    active = _current_tenant.get()
    if active is None:
        raise RuntimeError(
            "No tenant context set. Use 'with tenant_scope(tenant_id):' or "
            "call 'initialize()' first."
        )
    return active


@contextmanager
def tenant_scope(tenant_id: str):
    """
    Bind ``tenant_id`` as the current tenant for the duration of the block.

    Example:
        >>> with tenant_scope('user_a'):
        ...     model.forward(input)  # Uses user_a's LoRA
        ...     loss.backward()
        ...     ddp.finish_grad_sync()  # Only syncs user_a's gradients
    """
    token = _current_tenant.set(tenant_id)
    try:
        yield tenant_id
    finally:
        # Restore whatever tenant (possibly None) was active before.
        _current_tenant.reset(token)


def generate_tenant_id() -> str:
    """Generate a short (8 hex chars) random tenant id."""
    # Equivalent to str(uuid4())[:8]: the first eight characters of a
    # canonical UUID are its leading hex digits.
    return uuid.uuid4().hex[:8]


@dataclass
class TenantInfo:
    """
    Metadata for a registered tenant.

    Lightweight record kept separate from any DDP-specific state.
    """
    tenant_id: str
    adapter_name: str
    process_group: Optional[dist.ProcessGroup] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


F = TypeVar('F', bound=Callable)


def with_tenant_context(func: F) -> F:
    """
    Decorator that fills in ``tenant_id`` from the ambient tenant context.

    The wrapped callable must accept a keyword-only/optional ``tenant_id``;
    when the caller omits it, the current context tenant is used (raising if
    none is bound).

    Example:
        >>> @with_tenant_context
        ... def finish_grad_sync(self, tenant_id: Optional[str] = None):
        ...     ...
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, tenant_id: Optional[str] = None, **kwargs):
        resolved = require_tenant() if tenant_id is None else tenant_id
        return func(*args, tenant_id=resolved, **kwargs)

    return wrapper  # type: ignore
@dataclass
class TenantState:
    """
    Per-tenant training state.

    Bundles identity (tenant_id, adapter_name), the tenant's trainable
    parameters, its optimizer/scheduler, and training configuration.
    """
    tenant_id: str
    adapter_name: str

    # Trainable LoRA parameters and their fully qualified names.
    params: List[nn.Parameter] = field(default_factory=list)
    param_names: Dict[nn.Parameter, str] = field(default_factory=dict)

    # Optimizer/scheduler owned by this tenant (None when it has no params).
    optimizer: Optional[torch.optim.Optimizer] = None
    scheduler: Optional[Any] = None

    # Training configuration.
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0

    # Process group used for this tenant's gradient sync.
    process_group: Optional[dist.ProcessGroup] = None


class TenantManager:
    """
    Tenant lifecycle manager for multi-tenant training.

    Handles tenant registration/deregistration, LoRA adapter management,
    optimizer/scheduler creation, and tenant context switching.  It is
    deliberately decoupled from DDP: gradient buffers and communication are
    integrated via the add/remove callbacks.

    Example:
        >>> manager = TenantManager(model)
        >>> tenant_id = manager.initialize(
        ...     lora_config=LoraConfig(r=8),
        ...     optimizer_cls=AdamW,
        ... )
        >>> with manager.scope(tenant_id):
        ...     pass  # all operations use this tenant
        >>> manager.finalize(tenant_id)
    """

    def __init__(
        self,
        model: nn.Module,
        default_process_group: Optional[dist.ProcessGroup] = None,
    ):
        """
        Args:
            model: Model with LoRA structure.
            default_process_group: Default process group for new tenants.
        """
        self.model = model
        self.default_process_group = default_process_group
        self._tenants: Dict[str, TenantState] = {}

        # Hooks for DDP integration, fired on tenant add/remove.
        self._on_add_callbacks: List[Callable[[TenantState], None]] = []
        self._on_remove_callbacks: List[Callable[[TenantState], None]] = []

    def register_add_callback(self, callback: Callable[[TenantState], None]):
        """Register a callback invoked whenever a tenant is added."""
        self._on_add_callbacks.append(callback)

    def register_remove_callback(self, callback: Callable[[TenantState], None]):
        """Register a callback invoked whenever a tenant is removed."""
        self._on_remove_callbacks.append(callback)

    def initialize(
        self,
        lora_config: Optional['LoraConfig'] = None,
        optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.AdamW,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        scheduler_cls: Optional[Type] = None,
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        process_group: Optional[dist.ProcessGroup] = None,
        adapter_name: Optional[str] = None,
        tenant_id: Optional[str] = None,
    ) -> str:
        """
        Register a new tenant: add its adapter, collect its trainable
        parameters, and build its optimizer/scheduler.

        Args:
            lora_config: LoRA configuration (adapter added only for PeftModel).
            optimizer_cls: Optimizer class.
            optimizer_kwargs: Optimizer arguments (defaults to ``{'lr': 1e-4}``).
            scheduler_cls: Scheduler class.
            scheduler_kwargs: Scheduler arguments.
            gradient_accumulation_steps: Steps to accumulate.
            max_grad_norm: Max gradient norm for clipping.
            process_group: Process group for gradient sync.
            adapter_name: Adapter name (defaults to tenant_id).
            tenant_id: Tenant ID (generated if not provided).

        Returns:
            The tenant ID.

        Raises:
            ValueError: If the tenant already exists.
        """
        tenant_id = tenant_id or generate_tenant_id()
        adapter_name = adapter_name or tenant_id
        process_group = process_group or self.default_process_group

        if tenant_id in self._tenants:
            raise ValueError(f"Tenant '{tenant_id}' already exists")

        # Attach the LoRA adapter.  LoRA-only training: no extra modules
        # saved, no bias parameters.
        if lora_config is not None and isinstance(self.model, PeftModel):
            lora_config.modules_to_save = None
            lora_config.bias = 'none'
            self.model.add_adapter(adapter_name, lora_config)
            logger.info(f"Added LoRA adapter '{adapter_name}'")

        # Collect this adapter's LoRA parameters and mark them trainable.
        trainable: List[nn.Parameter] = []
        name_of: Dict[nn.Parameter, str] = {}
        for name, param in self.model.named_parameters():
            if f'.{adapter_name}.' in name and 'lora_' in name:
                param.requires_grad = True
                trainable.append(param)
                name_of[param] = name

        if not trainable:
            logger.warning(f"No trainable params found for tenant '{tenant_id}'")

        # Optimizer only makes sense when there is something to train.
        opt_kwargs = optimizer_kwargs or {'lr': 1e-4}
        optimizer = optimizer_cls(trainable, **opt_kwargs) if trainable else None

        scheduler = None
        if scheduler_cls and optimizer:
            scheduler = scheduler_cls(optimizer, **(scheduler_kwargs or {}))

        state = TenantState(
            tenant_id=tenant_id,
            adapter_name=adapter_name,
            params=trainable,
            param_names=name_of,
            optimizer=optimizer,
            scheduler=scheduler,
            gradient_accumulation_steps=gradient_accumulation_steps,
            max_grad_norm=max_grad_norm,
            process_group=process_group,
        )
        self._tenants[tenant_id] = state

        # Let DDP (or other integrations) wire up the new tenant.
        for callback in self._on_add_callbacks:
            callback(state)

        # Make the new tenant the ambient one.
        set_current_tenant(tenant_id)

        logger.info(
            f"Initialized tenant '{tenant_id}' with {len(trainable)} params "
            f"({sum(p.numel() for p in trainable):,} elements)"
        )
        return tenant_id

    def finalize(self, tenant_id: Optional[str] = None):
        """
        Tear down a tenant: fire remove callbacks, delete its adapter,
        and clear the ambient context if it pointed at this tenant.

        Args:
            tenant_id: Tenant to finalize. Uses current if None.
        """
        tenant_id = tenant_id or get_current_tenant()
        if not tenant_id or tenant_id not in self._tenants:
            return

        state = self._tenants.pop(tenant_id)

        # Let DDP (or other integrations) release per-tenant resources.
        for callback in self._on_remove_callbacks:
            callback(state)

        if isinstance(self.model, PeftModel):
            try:
                self.model.delete_adapter(state.adapter_name)
            except Exception as e:
                # Best-effort: a failed delete should not abort teardown.
                logger.warning(f"Failed to delete adapter: {e}")

        if get_current_tenant() == tenant_id:
            set_current_tenant(None)

        logger.info(f"Finalized tenant '{tenant_id}'")

    @contextmanager
    def scope(self, tenant_id: Optional[str] = None):
        """Context manager binding a tenant and yielding its state."""
        tenant_id = tenant_id or require_tenant()
        with tenant_scope(tenant_id):
            yield self.get(tenant_id)

    def get(self, tenant_id: Optional[str] = None) -> TenantState:
        """Return a tenant's state (current context tenant when None)."""
        tenant_id = tenant_id or require_tenant()
        if tenant_id not in self._tenants:
            raise KeyError(f"Tenant '{tenant_id}' not found")
        return self._tenants[tenant_id]

    def has(self, tenant_id: str) -> bool:
        """Return True if the tenant is registered."""
        return tenant_id in self._tenants

    def count(self) -> int:
        """Number of tenants (does not expose tenant IDs for privacy)."""
        return len(self._tenants)

    # Note: list() method intentionally not exposed to clients to prevent
    # information leakage. Only server-side code should enumerate tenants.
""" -from .bridge import ( - # Main classes - TwinkleBridgeAdapter, - TwinkleBridgeInitializer, - TwinkleGPTBridge, - BridgeConfig, - SafetensorLoader, - StreamingSafetensorSaver, - LazyTensor, - # Helper functions - deep_getattr, - is_last_rank, - load_hf_weights_to_megatron, - # Legacy compatibility - create_megatron_args, - set_megatron_args, - restore_megatron_args, - mock_megatron_args, -) +from .bridge import ( # Main classes; Helper functions; Legacy compatibility + BridgeConfig, LazyTensor, SafetensorLoader, StreamingSafetensorSaver, + TwinkleBridgeAdapter, TwinkleBridgeInitializer, TwinkleGPTBridge, + create_megatron_args, deep_getattr, is_last_rank, + load_hf_weights_to_megatron, mock_megatron_args, restore_megatron_args, + set_megatron_args) from .initializer import MegatronModelInitializer, initialize_megatron_model +from .multi_tenant_megatron import (MegatronMultiAdapter, + MultiTenantMegatronModel) from .qwen3 import Qwen3ModelMeta, get_model_default_config __all__ = [ @@ -50,4 +40,7 @@ # Model metadata 'Qwen3ModelMeta', 'get_model_default_config', + # Multi-tenant + 'MultiTenantMegatronModel', + 'MegatronMultiAdapter', ] diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index 1d21a0bb..ea1face1 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -1,20 +1,22 @@ # Copyright (c) twinkle authors. All rights reserved. # GPT Bridge for HuggingFace to Megatron-Core weight conversion. 
-from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union -from types import SimpleNamespace -from dataclasses import dataclass, field -from copy import copy -import os -import json import glob +import json import math +import os +from copy import copy +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist from tqdm import tqdm + from twinkle.hub import HubOperation + try: from megatron.core import parallel_state as mpu MEGATRON_AVAILABLE = True @@ -42,14 +44,14 @@ def deep_getattr(obj, attr: str, default=None): def is_last_rank() -> bool: """Check if current process is the last rank for writing. - + For DP > 1, we want only DP rank 0 to write to avoid conflicts. For PP, we want the last PP stage. For TP, all TP ranks participate in gather, but only one writes. 
""" if not dist.is_initialized(): return True - + try: from megatron.core import parallel_state as mpu if mpu.is_initialized(): @@ -62,17 +64,16 @@ def is_last_rank() -> bool: return True except (ImportError, AssertionError): pass - + return dist.get_rank() == dist.get_world_size() - 1 class LazyTensor: """Lazy tensor wrapper for deferred loading.""" - def __init__(self, loader, key: str): self._loader = loader self._key = key - + def load(self) -> torch.Tensor: """Load the tensor.""" return self._loader.get_tensor(self._key) @@ -80,7 +81,6 @@ def load(self) -> torch.Tensor: class SafetensorLoader: """Lazy loader for safetensor files.""" - def __init__(self, model_dir: str, is_peft_format: bool = False): self.model_dir = model_dir self.is_peft_format = is_peft_format @@ -88,21 +88,23 @@ def __init__(self, model_dir: str, is_peft_format: bool = False): self._index = None self._key_to_file = {} self._load_index() - + def _load_index(self): """Load safetensor index file if exists.""" # Try adapter format first for PEFT if self.is_peft_format: - adapter_file = os.path.join(self.model_dir, 'adapter_model.safetensors') + adapter_file = os.path.join(self.model_dir, + 'adapter_model.safetensors') if os.path.exists(adapter_file): handle = safe_open(adapter_file, framework='pt', device='cpu') for key in handle.keys(): self._key_to_file[key] = adapter_file self._handles[adapter_file] = handle return - + # Standard index file - index_file = os.path.join(self.model_dir, 'model.safetensors.index.json') + index_file = os.path.join(self.model_dir, + 'model.safetensors.index.json') if os.path.exists(index_file): with open(index_file, 'r') as f: self._index = json.load(f) @@ -118,123 +120,128 @@ def _load_index(self): self._handles[single_file] = handle else: # Try to find any safetensor file - files = glob.glob(os.path.join(self.model_dir, '*.safetensors')) + files = glob.glob(os.path.join(self.model_dir, + '*.safetensors')) for filepath in files: handle = safe_open(filepath, 
framework='pt', device='cpu') for key in handle.keys(): self._key_to_file[key] = filepath self._handles[filepath] = handle - + def _get_handle(self, filepath: str): """Get or create file handle.""" if filepath not in self._handles: - self._handles[filepath] = safe_open(filepath, framework='pt', device='cpu') + self._handles[filepath] = safe_open(filepath, + framework='pt', + device='cpu') return self._handles[filepath] - + def get_tensor(self, key: str) -> torch.Tensor: """Load a single tensor.""" filepath = self._key_to_file.get(key) if filepath is None: - raise KeyError(f"Tensor key not found: {key}") + raise KeyError(f'Tensor key not found: {key}') handle = self._get_handle(filepath) return handle.get_tensor(key) - + def get_lazy(self, key: str) -> LazyTensor: """Get a lazy tensor reference.""" if key not in self._key_to_file: - raise KeyError(f"Tensor key not found: {key}") + raise KeyError(f'Tensor key not found: {key}') return LazyTensor(self, key) - + def get_state_dict(self) -> Dict[str, LazyTensor]: """Get lazy state dict.""" return {key: LazyTensor(self, key) for key in self._key_to_file} - + def keys(self) -> List[str]: """Get all tensor keys.""" return list(self._key_to_file.keys()) - + def __contains__(self, key: str) -> bool: return key in self._key_to_file - + def close(self): """Close all file handles.""" self._handles.clear() - + def __enter__(self): return self - + def __exit__(self, *args): self.close() class StreamingSafetensorSaver: """Streaming saver for safetensor files.""" - - def __init__(self, save_dir: str, max_shard_size: str = '5GB', is_peft_format: bool = False): + def __init__(self, + save_dir: str, + max_shard_size: str = '5GB', + is_peft_format: bool = False): self.save_dir = save_dir self.is_peft_format = is_peft_format os.makedirs(save_dir, exist_ok=True) - + # Parse max shard size size_str = max_shard_size.upper() if size_str.endswith('GB'): - self.max_shard_bytes = int(float(size_str[:-2]) * 1024 ** 3) + self.max_shard_bytes = 
int(float(size_str[:-2]) * 1024**3) elif size_str.endswith('MB'): - self.max_shard_bytes = int(float(size_str[:-2]) * 1024 ** 2) + self.max_shard_bytes = int(float(size_str[:-2]) * 1024**2) else: self.max_shard_bytes = int(size_str) - + self.current_shard = {} self.current_shard_size = 0 self.shard_idx = 1 self.weight_map = {} - + def add_tensor(self, key: str, tensor: torch.Tensor): """Add tensor to the current shard.""" if tensor is None: return - + tensor_size = tensor.numel() * tensor.element_size() - + # Flush if needed if self.current_shard_size + tensor_size > self.max_shard_bytes and self.current_shard: self._flush_shard() - + self.current_shard[key] = tensor.contiguous() self.current_shard_size += tensor_size - + def _flush_shard(self): """Flush current shard to disk.""" if not self.current_shard: return - + if self.is_peft_format: filename = 'adapter_model.safetensors' else: filename = f'model-{self.shard_idx:05d}-of-XXXXX.safetensors' - + filepath = os.path.join(self.save_dir, filename) save_file(self.current_shard, filepath) - + for key in self.current_shard: self.weight_map[key] = filename - + self.current_shard = {} self.current_shard_size = 0 self.shard_idx += 1 - + def finalize(self): """Finalize and write index.""" self._flush_shard() - + if self.is_peft_format: return # PEFT format doesn't need index - + # Fix shard filenames total_shards = self.shard_idx - 1 if total_shards == 0: return - + for old_name in list(self.weight_map.values()): new_name = old_name.replace('XXXXX', f'{total_shards:05d}') if old_name != new_name: @@ -245,14 +252,19 @@ def finalize(self): for key in self.weight_map: if self.weight_map[key] == old_name: self.weight_map[key] = new_name - + if total_shards > 1: index = { - 'metadata': {'total_size': sum(t.numel() * t.element_size() - for t in self.current_shard.values())}, + 'metadata': { + 'total_size': + sum(t.numel() * t.element_size() + for t in self.current_shard.values()) + }, 'weight_map': self.weight_map } - with 
open(os.path.join(self.save_dir, 'model.safetensors.index.json'), 'w') as f: + with open( + os.path.join(self.save_dir, + 'model.safetensors.index.json'), 'w') as f: json.dump(index, f, indent=2) @@ -264,7 +276,7 @@ class BridgeConfig: pp_size: int = 1 ep_size: int = 1 etp_size: int = 1 - + # Model architecture hidden_size: int = 4096 num_attention_heads: int = 32 @@ -274,21 +286,21 @@ class BridgeConfig: padded_vocab_size: int = 32000 intermediate_size: int = 11008 kv_channels: int = None # head_dim, if None will be computed from hidden_size // num_attention_heads - + # Options add_qkv_bias: bool = False add_bias_linear: bool = False qk_layernorm: bool = False tie_word_embeddings: bool = False - + # MoE num_experts: int = 0 num_experts_per_tok: int = 2 shared_expert_intermediate_size: int = 0 - + model_type: str = 'qwen2' max_shard_size: str = '5GB' - + @classmethod def from_hf_config( cls, @@ -305,18 +317,20 @@ def from_hf_config( # Pad to multiple of 64 for efficiency if padded_vocab_size % 64 != 0: padded_vocab_size = ((padded_vocab_size // 64) + 1) * 64 - + num_attention_heads = getattr(hf_config, 'num_attention_heads', 32) - num_key_value_heads = getattr(hf_config, 'num_key_value_heads', num_attention_heads) - + num_key_value_heads = getattr(hf_config, 'num_key_value_heads', + num_attention_heads) + # MoE config num_experts = getattr(hf_config, 'num_experts', 0) or \ getattr(hf_config, 'n_routed_experts', 0) or \ getattr(hf_config, 'num_local_experts', 0) num_experts_per_tok = getattr(hf_config, 'num_experts_per_tok', 2) or \ getattr(hf_config, 'moe_topk', 2) - shared_expert_size = getattr(hf_config, 'shared_expert_intermediate_size', 0) - + shared_expert_size = getattr(hf_config, + 'shared_expert_intermediate_size', 0) + # Determine QKV bias setting # Qwen2 has attention bias by default (hardcoded in transformers), # but config doesn't have 'attention_bias' field @@ -328,7 +342,7 @@ def from_hf_config( add_qkv_bias = True else: add_qkv_bias = False - + # 
Determine QK layernorm setting # Qwen3 uses QK layernorm but doesn't have explicit config attribute qk_layernorm = getattr(hf_config, 'qk_layernorm', False) or \ @@ -336,10 +350,10 @@ def from_hf_config( if not qk_layernorm and model_type in ('qwen3', 'qwen3_moe'): # Qwen3 (dense and MoE) always uses QK layernorm (q_norm, k_norm weights) qk_layernorm = True - + # Determine kv_channels (head_dim) - Qwen3 has explicit head_dim kv_channels = getattr(hf_config, 'head_dim', None) - + return cls( tp_size=tp_size, pp_size=pp_size, @@ -355,7 +369,8 @@ def from_hf_config( add_qkv_bias=add_qkv_bias, add_bias_linear=getattr(hf_config, 'mlp_bias', False), qk_layernorm=qk_layernorm, - tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False), + tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', + False), num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, shared_expert_intermediate_size=shared_expert_size, @@ -366,19 +381,22 @@ def from_hf_config( class TwinkleGPTBridge: """Bridge for converting weights between HuggingFace and Megatron-Core formats. - + Supports Qwen2.5 / Qwen3 model families. """ - + # HuggingFace model structure constants (Qwen2/Qwen3 compatible) HF_LAYERS_PREFIX = 'model.layers' HF_EMBED_KEY = 'model.embed_tokens.weight' HF_FINAL_LAYERNORM_KEY = 'model.norm.weight' HF_LM_HEAD_KEY = 'lm_head.weight' - - def __init__(self, config: BridgeConfig, hf_config: Any = None, disable_tqdm: bool = False): + + def __init__(self, + config: BridgeConfig, + hf_config: Any = None, + disable_tqdm: bool = False): """Initialize the bridge. - + Args: config: Bridge configuration. hf_config: HuggingFace model config (for reference). 
@@ -387,13 +405,13 @@ def __init__(self, config: BridgeConfig, hf_config: Any = None, disable_tqdm: bo self.config = config self.hf_config = hf_config self.disable_tqdm = disable_tqdm or not is_last_rank() - + # Parallel state self.tp_size = config.tp_size self.pp_size = config.pp_size self.ep_size = config.ep_size self.etp_size = config.etp_size - + # Get parallel ranks if MEGATRON_AVAILABLE and mpu.is_initialized(): self.tp_rank = mpu.get_tensor_model_parallel_rank() @@ -419,7 +437,7 @@ def __init__(self, config: BridgeConfig, hf_config: Any = None, disable_tqdm: bo self.ep_group = None self.etp_rank = 0 self.etp_group = None - + # PEFT tracking self._is_peft_format = False self._adapter_name = 'default' @@ -427,28 +445,32 @@ def __init__(self, config: BridgeConfig, hf_config: Any = None, disable_tqdm: bo self._peft_modules_to_save: Set[str] = set() self._target_device = None self._only_last_rank = False - + def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: """Determine which dimension to split for tensor parallelism.""" if mg_key is None: return None - + # ColumnParallel (split output dim) dim0_keys = { - 'word_embeddings', 'linear_qkv', 'output_layer', - 'linear_q_proj', 'linear_q_up_proj', 'linear_kv_up_proj', + 'word_embeddings', + 'linear_qkv', + 'output_layer', + 'linear_q_proj', + 'linear_q_up_proj', + 'linear_kv_up_proj', 'eh_proj', # MTP } # RowParallel (split input dim) dim1_keys = {'linear_proj', 'linear_fc2'} - + # Handle LoRA keys if 'lora_A' not in mg_key and 'lora_B' not in mg_key: key_parts = mg_key.rsplit('.', 2) if len(key_parts) >= 2: key = key_parts[-2] suffix = key_parts[-1] - + if suffix == 'layer_norm_weight': return None elif key in dim0_keys: @@ -471,31 +493,36 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: return 0 elif key == 'linear_fc1': return 1 - + return None - - def _split_tp(self, tensor: torch.Tensor, tp_dim: Optional[int], is_expert: bool = False) -> torch.Tensor: + + def 
_split_tp(self, + tensor: torch.Tensor, + tp_dim: Optional[int], + is_expert: bool = False) -> torch.Tensor: """Split tensor for tensor parallelism.""" tp_size = self.etp_size if is_expert else self.tp_size tp_rank = self.etp_rank if is_expert else self.tp_rank - + if tp_dim is None or tp_size <= 1: return tensor return tensor.chunk(tp_size, dim=tp_dim)[tp_rank] - - def _all_gather_tp(self, tensor: Optional[torch.Tensor], tp_dim: Optional[int], + + def _all_gather_tp(self, + tensor: Optional[torch.Tensor], + tp_dim: Optional[int], is_expert: bool = False) -> Optional[torch.Tensor]: """All-gather tensor across tensor parallel group.""" if tensor is None: return None - + tensor = tensor.to('cuda') tp_size = self.etp_size if is_expert else self.tp_size tp_group = self.etp_group if is_expert else self.tp_group - + if tp_dim is None or tp_size <= 1: return tensor - + if tp_dim == 0: tensor_shape = list(tensor.shape) tensor_shape[0] *= tp_size @@ -506,7 +533,7 @@ def _all_gather_tp(self, tensor: Optional[torch.Tensor], tp_dim: Optional[int], output = [torch.empty_like(tensor) for _ in range(tp_size)] dist.all_gather(output, tensor, group=tp_group) return torch.cat(output, dim=tp_dim) - + def _set_weight( self, mg_param: Union[torch.Tensor, nn.Parameter, List], @@ -517,15 +544,15 @@ def _set_weight( """Set weight from HuggingFace to Megatron parameter.""" tp_dim = self._get_tp_split_dim(mg_key) tensor = self._split_tp(hf_weight, tp_dim, is_expert) - + if not isinstance(mg_param, (list, tuple)): mg_param = [mg_param] - + tensor_list = tensor.chunk(len(mg_param), dim=0) for i, param in enumerate(mg_param): t = tensor_list[i].reshape(*param.shape) param.data.copy_(t) - + def _get_weight( self, mg_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]], @@ -535,133 +562,141 @@ def _get_weight( """Get weight from Megatron parameter, gathered across TP.""" if mg_weight is None: return None, None - + tensor = mg_weight if not isinstance(tensor, (list, tuple)): tensor = 
[tensor] - + tensor = torch.cat(tensor, dim=0) tp_dim = self._get_tp_split_dim(mg_key) tensor = self._all_gather_tp(tensor, tp_dim, is_expert) - + if self._target_device is not None and tensor is not None: tensor = tensor.to(device=self._target_device) - + if self._only_last_rank and not is_last_rank(): return None, None - + return tensor, None - + # ========================================================================= # Weight Loading Methods # ========================================================================= - + def _load_embedding(self, mg_model, loader: SafetensorLoader): """Load embedding weights.""" embed_module = deep_getattr(mg_model, 'embedding.word_embeddings') if embed_module is None: return - + hf_weight = loader.get_tensor(self.HF_EMBED_KEY) - + # Pad vocabulary if needed if hf_weight.shape[0] < self.config.padded_vocab_size: hf_weight = F.pad( - hf_weight, - (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0]) - ) - - self._set_weight(embed_module.weight, hf_weight, 'word_embeddings.weight') - + hf_weight, + (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0])) + + self._set_weight(embed_module.weight, hf_weight, + 'word_embeddings.weight') + def _load_output_layer(self, mg_model, loader: SafetensorLoader): """Load output layer (lm_head) weights.""" output_module = deep_getattr(mg_model, 'output_layer') if output_module is None or output_module.weight is None: return - + # Check if weights are tied if self.config.tie_word_embeddings: hf_weight = loader.get_tensor(self.HF_EMBED_KEY) else: hf_weight = loader.get_tensor(self.HF_LM_HEAD_KEY) - + # Pad vocabulary if needed if hf_weight.shape[0] < self.config.padded_vocab_size: hf_weight = F.pad( hf_weight, - (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0]) - ) - - self._set_weight(output_module.weight, hf_weight, 'output_layer.weight') - + (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0])) + + self._set_weight(output_module.weight, hf_weight, + 
'output_layer.weight') + def _load_final_layernorm(self, mg_model, loader: SafetensorLoader): """Load final layer norm weights.""" ln_module = deep_getattr(mg_model, 'decoder.final_layernorm') if ln_module is None: return - + hf_weight = loader.get_tensor(self.HF_FINAL_LAYERNORM_KEY) ln_module.weight.data.copy_(hf_weight) - - def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + + def _load_attention(self, mg_layer, loader: SafetensorLoader, + layer_idx: int): """Load attention layer weights.""" mg_attn = mg_layer.self_attention prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.self_attn.' - + num_heads = self.config.num_attention_heads num_kv_heads = self.config.num_key_value_heads hidden_size = self.config.hidden_size # Use kv_channels (head_dim) from config if available (for Qwen3 etc.) - head_dim = getattr(self.config, 'kv_channels', hidden_size // num_heads) + head_dim = getattr(self.config, 'kv_channels', + hidden_size // num_heads) heads_per_group = num_heads // num_kv_heads - + # Load Q, K, V weights and merge into linear_qkv q_weight = loader.get_tensor(f'{prefix}q_proj.weight') k_weight = loader.get_tensor(f'{prefix}k_proj.weight') v_weight = loader.get_tensor(f'{prefix}v_proj.weight') - + # Infer head_dim from actual weight shapes if needed actual_kv_dim = k_weight.shape[0] // num_kv_heads if actual_kv_dim != head_dim: head_dim = actual_kv_dim - + # Reshape for GQA - q_weight = q_weight.reshape(num_kv_heads, heads_per_group * head_dim, hidden_size) + q_weight = q_weight.reshape(num_kv_heads, heads_per_group * head_dim, + hidden_size) k_weight = k_weight.reshape(num_kv_heads, head_dim, hidden_size) v_weight = v_weight.reshape(num_kv_heads, head_dim, hidden_size) - + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=1) qkv_weight = qkv_weight.reshape(-1, hidden_size) - - self._set_weight(mg_attn.linear_qkv.weight, qkv_weight, 'linear_qkv.weight') - + + self._set_weight(mg_attn.linear_qkv.weight, qkv_weight, + 
'linear_qkv.weight') + # Load O projection o_weight = loader.get_tensor(f'{prefix}o_proj.weight') - self._set_weight(mg_attn.linear_proj.weight, o_weight, 'linear_proj.weight') - + self._set_weight(mg_attn.linear_proj.weight, o_weight, + 'linear_proj.weight') + # Load biases if present if self.config.add_qkv_bias: try: q_bias = loader.get_tensor(f'{prefix}q_proj.bias') k_bias = loader.get_tensor(f'{prefix}k_proj.bias') v_bias = loader.get_tensor(f'{prefix}v_proj.bias') - + # Infer head_dim from actual bias shapes if needed actual_bias_head_dim = k_bias.shape[0] // num_kv_heads - - q_bias = q_bias.reshape(num_kv_heads, heads_per_group * actual_bias_head_dim) + + q_bias = q_bias.reshape(num_kv_heads, + heads_per_group * actual_bias_head_dim) k_bias = k_bias.reshape(num_kv_heads, actual_bias_head_dim) v_bias = v_bias.reshape(num_kv_heads, actual_bias_head_dim) - - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).reshape(-1) - self._set_weight(mg_attn.linear_qkv.bias, qkv_bias, 'linear_qkv.bias') + + qkv_bias = torch.cat([q_bias, k_bias, v_bias], + dim=1).reshape(-1) + self._set_weight(mg_attn.linear_qkv.bias, qkv_bias, + 'linear_qkv.bias') except KeyError: pass - + # Load input layernorm (may be fused) ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.input_layernorm.weight' ln_weight = loader.get_tensor(ln_key) - + ln_param = deep_getattr(mg_attn, 'linear_qkv.layer_norm_weight') if ln_param is not None: ln_param.data.copy_(ln_weight) @@ -669,7 +704,7 @@ def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): ln_module = deep_getattr(mg_layer, 'input_layernorm') if ln_module is not None: ln_module.weight.data.copy_(ln_weight) - + # QK layernorm (Qwen3) if self.config.qk_layernorm: try: @@ -683,41 +718,46 @@ def _load_attention(self, mg_layer, loader: SafetensorLoader, layer_idx: int): k_ln.weight.data.copy_(k_norm) except KeyError: pass - + def _load_mlp(self, mg_layer, loader: SafetensorLoader, layer_idx: int): """Load MLP layer weights.""" 
mg_mlp = mg_layer.mlp prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' - + # Check if gate_up_proj is fused try: gate_weight = loader.get_tensor(f'{prefix}gate_proj.weight') up_weight = loader.get_tensor(f'{prefix}up_proj.weight') - + # Stack gate and up projections (shape: [2, intermediate, hidden]) fc1_weight = torch.stack([gate_weight, up_weight], dim=0) - self._set_weight(mg_mlp.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') + self._set_weight(mg_mlp.linear_fc1.weight, fc1_weight, + 'linear_fc1.weight') except KeyError: # Try gate_up_proj (fused) try: - gate_up_weight = loader.get_tensor(f'{prefix}gate_up_proj.weight') - gate_up_weight = gate_up_weight.view(2, -1, gate_up_weight.shape[-1]) - self._set_weight(mg_mlp.linear_fc1.weight, gate_up_weight, 'linear_fc1.weight') + gate_up_weight = loader.get_tensor( + f'{prefix}gate_up_proj.weight') + gate_up_weight = gate_up_weight.view(2, -1, + gate_up_weight.shape[-1]) + self._set_weight(mg_mlp.linear_fc1.weight, gate_up_weight, + 'linear_fc1.weight') except KeyError: pass - + # Load down projection try: down_weight = loader.get_tensor(f'{prefix}down_proj.weight') - self._set_weight(mg_mlp.linear_fc2.weight, down_weight, 'linear_fc2.weight') + self._set_weight(mg_mlp.linear_fc2.weight, down_weight, + 'linear_fc2.weight') except KeyError: pass - + # Load post attention layernorm ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' try: ln_weight = loader.get_tensor(ln_key) - + ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') if ln_param is not None: ln_param.data.copy_(ln_weight) @@ -727,20 +767,20 @@ def _load_mlp(self, mg_layer, loader: SafetensorLoader, layer_idx: int): ln_module.weight.data.copy_(ln_weight) except KeyError: pass - + def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): """Load MoE layer weights. - + Handles Expert Parallel (EP) sharding - each EP rank loads only its assigned subset of experts based on ep_rank and ep_size. 
- + For EP=2 with 128 experts: - EP rank 0 loads experts 0-63 - EP rank 1 loads experts 64-127 """ mg_mlp = mg_layer.mlp prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' - + # Load router (replicated across all ranks) try: router_key = None @@ -749,44 +789,57 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): if full_key in loader: router_key = full_key break - + if router_key: router_weight = loader.get_tensor(router_key) router_module = deep_getattr(mg_mlp, 'router') - if router_module is not None and hasattr(router_module, 'weight'): + if router_module is not None and hasattr( + router_module, 'weight'): router_module.weight.data.copy_(router_weight) - + # Load expert bias if present (for sigmoid routers like Qwen3) - for bias_key in ['gate.e_score_correction_bias', 'moe_statics.e_score_correction_bias']: + for bias_key in [ + 'gate.e_score_correction_bias', + 'moe_statics.e_score_correction_bias' + ]: full_bias_key = f'{prefix}{bias_key}' if full_bias_key in loader: try: expert_bias = loader.get_tensor(full_bias_key) - if router_module is not None and hasattr(router_module, 'expert_bias'): + if router_module is not None and hasattr( + router_module, 'expert_bias'): router_module.expert_bias.data.copy_(expert_bias) break except KeyError: continue except KeyError: pass - + # Load shared experts if present if self.config.shared_expert_intermediate_size > 0: - for shared_key in ['shared_expert', 'shared_experts', 'shared_mlp']: + for shared_key in [ + 'shared_expert', 'shared_experts', 'shared_mlp' + ]: try: - gate_weight = loader.get_tensor(f'{prefix}{shared_key}.gate_proj.weight') - up_weight = loader.get_tensor(f'{prefix}{shared_key}.up_proj.weight') - down_weight = loader.get_tensor(f'{prefix}{shared_key}.down_proj.weight') - + gate_weight = loader.get_tensor( + f'{prefix}{shared_key}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}{shared_key}.up_proj.weight') + down_weight = loader.get_tensor( + 
f'{prefix}{shared_key}.down_proj.weight') + shared_module = deep_getattr(mg_mlp, 'shared_experts') if shared_module is not None: - fc1_weight = torch.stack([gate_weight, up_weight], dim=0) - self._set_weight(shared_module.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') - self._set_weight(shared_module.linear_fc2.weight, down_weight, 'linear_fc2.weight') + fc1_weight = torch.stack([gate_weight, up_weight], + dim=0) + self._set_weight(shared_module.linear_fc1.weight, + fc1_weight, 'linear_fc1.weight') + self._set_weight(shared_module.linear_fc2.weight, + down_weight, 'linear_fc2.weight') break except KeyError: continue - + # Load shared expert gate if present for gate_key in ['shared_expert_gate.weight']: full_gate_key = f'{prefix}{gate_key}' @@ -794,17 +847,18 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): try: gate_weight = loader.get_tensor(full_gate_key) shared_module = deep_getattr(mg_mlp, 'shared_experts') - if shared_module is not None and hasattr(shared_module, 'gate_weight'): + if shared_module is not None and hasattr( + shared_module, 'gate_weight'): shared_module.gate_weight.data.copy_(gate_weight) break except KeyError: continue - + # Load experts with EP sharding num_local_experts = self.config.num_experts // self.ep_size start_expert_idx = self.ep_rank * num_local_experts experts_module = deep_getattr(mg_mlp, 'experts') - + if experts_module is not None: # Determine expert module type if hasattr(experts_module, 'weight1'): @@ -812,31 +866,40 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): # Need to collect all experts and set at once fc1_weights = [] # gate and up weights interleaved fc2_weights = [] # down weights - + for local_idx in range(num_local_experts): global_idx = start_expert_idx + local_idx try: - gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') - up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') - down_weight = 
loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') - + gate_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.down_proj.weight') + # Stack gate and up for gated linear unit fc1_weights.append(gate_weight) # [ffn_hidden, hidden] - fc1_weights.append(up_weight) # [ffn_hidden, hidden] + fc1_weights.append(up_weight) # [ffn_hidden, hidden] fc2_weights.append(down_weight) # [hidden, ffn_hidden] except KeyError as e: - print(f"Warning: Missing expert {global_idx} weights: {e}") + print( + f'Warning: Missing expert {global_idx} weights: {e}' + ) continue - + if fc1_weights and fc2_weights: # GroupedMLP weight1: [hidden, num_experts * 2 * ffn_hidden] (transposed) # HF format: [num_experts * 2, ffn_hidden, hidden] - fc1_stacked = torch.cat(fc1_weights, dim=0) # [num_experts*2*ffn_hidden, hidden] - fc1_stacked = fc1_stacked.t().contiguous() # [hidden, num_experts*2*ffn_hidden] - + fc1_stacked = torch.cat( + fc1_weights, + dim=0) # [num_experts*2*ffn_hidden, hidden] + fc1_stacked = fc1_stacked.t().contiguous( + ) # [hidden, num_experts*2*ffn_hidden] + # GroupedMLP weight2: [num_experts * ffn_hidden, hidden] - fc2_stacked = torch.cat(fc2_weights, dim=0) # [num_experts*hidden, ffn_hidden] - + fc2_stacked = torch.cat( + fc2_weights, dim=0) # [num_experts*hidden, ffn_hidden] + # Set weights directly if experts_module.weight1.shape == fc1_stacked.shape: experts_module.weight1.data.copy_(fc1_stacked) @@ -847,62 +910,84 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): if tp_size > 1: # Split along last dim for weight1 chunk_size = fc1_stacked.shape[1] // tp_size - fc1_chunk = fc1_stacked[:, tp_rank * chunk_size:(tp_rank + 1) * chunk_size] + fc1_chunk = fc1_stacked[:, tp_rank * + chunk_size:(tp_rank + 1) * + chunk_size] 
experts_module.weight1.data.copy_(fc1_chunk) else: experts_module.weight1.data.copy_(fc1_stacked) - + if experts_module.weight2.shape == fc2_stacked.shape: experts_module.weight2.data.copy_(fc2_stacked) else: - # Handle TP split + # Handle TP split tp_rank = self.tp_rank tp_size = self.tp_size if tp_size > 1: # Split along first dim for weight2 chunk_size = fc2_stacked.shape[0] // tp_size - fc2_chunk = fc2_stacked[tp_rank * chunk_size:(tp_rank + 1) * chunk_size, :] + fc2_chunk = fc2_stacked[tp_rank * + chunk_size:(tp_rank + 1) * + chunk_size, :] experts_module.weight2.data.copy_(fc2_chunk) else: experts_module.weight2.data.copy_(fc2_stacked) - + elif hasattr(experts_module, 'local_experts'): # SequentialMLP format with local_experts list for local_idx in range(num_local_experts): global_idx = start_expert_idx + local_idx try: - gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') - up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') - down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') - + gate_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.down_proj.weight') + expert = experts_module.local_experts[local_idx] if hasattr(expert, 'linear_fc1'): - fc1_weight = torch.stack([gate_weight, up_weight], dim=0) - self._set_weight(expert.linear_fc1.weight, fc1_weight, 'linear_fc1.weight') - self._set_weight(expert.linear_fc2.weight, down_weight, 'linear_fc2.weight') + fc1_weight = torch.stack([gate_weight, up_weight], + dim=0) + self._set_weight(expert.linear_fc1.weight, + fc1_weight, 'linear_fc1.weight') + self._set_weight(expert.linear_fc2.weight, + down_weight, 'linear_fc2.weight') except KeyError: continue - + elif hasattr(experts_module, 'linear_fc1'): # TEGroupedLinear format - weights stored as weight0, 
weight1, etc. for local_idx in range(num_local_experts): global_idx = start_expert_idx + local_idx try: - gate_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.gate_proj.weight') - up_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.up_proj.weight') - down_weight = loader.get_tensor(f'{prefix}experts.{global_idx}.down_proj.weight') - - fc1_weight = torch.stack([gate_weight, up_weight], dim=0) - fc1_param = getattr(experts_module.linear_fc1, f'weight{local_idx}', None) + gate_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.down_proj.weight') + + fc1_weight = torch.stack([gate_weight, up_weight], + dim=0) + fc1_param = getattr(experts_module.linear_fc1, + f'weight{local_idx}', None) if fc1_param is not None: - self._set_weight(fc1_param, fc1_weight, 'linear_fc1.weight', is_expert=True) - - fc2_param = getattr(experts_module.linear_fc2, f'weight{local_idx}', None) + self._set_weight(fc1_param, + fc1_weight, + 'linear_fc1.weight', + is_expert=True) + + fc2_param = getattr(experts_module.linear_fc2, + f'weight{local_idx}', None) if fc2_param is not None: - self._set_weight(fc2_param, down_weight, 'linear_fc2.weight', is_expert=True) + self._set_weight(fc2_param, + down_weight, + 'linear_fc2.weight', + is_expert=True) except KeyError: continue - + # Load post attention layernorm (pre_mlp_layernorm for MoE) ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' try: @@ -918,17 +1003,17 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): ln_param.data.copy_(ln_weight) except KeyError: pass - + def _load_layer(self, mg_layer, loader: SafetensorLoader, layer_idx: int): """Load a single transformer layer.""" self._load_attention(mg_layer, loader, layer_idx) - + # Check if MoE layer if self.config.num_experts > 0: 
self._load_moe(mg_layer, loader, layer_idx) else: self._load_mlp(mg_layer, loader, layer_idx) - + def load_weights( self, mg_model: nn.Module, @@ -937,7 +1022,7 @@ def load_weights( adapter_name: str = 'default', ) -> None: """Load HuggingFace weights into Megatron model. - + Args: mg_model: Megatron GPT model. model_path: Path to HuggingFace checkpoint. @@ -946,87 +1031,89 @@ def load_weights( """ self._is_peft_format = is_peft_format self._adapter_name = adapter_name - + with torch.no_grad(): - with SafetensorLoader(model_path, is_peft_format=is_peft_format) as loader: + with SafetensorLoader(model_path, + is_peft_format=is_peft_format) as loader: if is_peft_format: self._load_peft_weights(mg_model, loader) else: self._load_base_weights(mg_model, loader) - - def _load_base_weights(self, mg_model: nn.Module, loader: SafetensorLoader): + + def _load_base_weights(self, mg_model: nn.Module, + loader: SafetensorLoader): """Load base model weights.""" # Get decoder decoder = deep_getattr(mg_model, 'decoder') if decoder is None: decoder = mg_model - + layers = getattr(decoder, 'layers', []) - + # Load pre-process (embedding) on first PP rank if self.pp_size <= 1 or self.pp_rank == 0: try: self._load_embedding(mg_model, loader) except Exception as e: - print(f"Warning: Failed to load embedding: {e}") - + print(f'Warning: Failed to load embedding: {e}') + # Load transformer layers - prog_bar = tqdm( - layers, - desc='Loading weights', - disable=self.disable_tqdm - ) + prog_bar = tqdm(layers, + desc='Loading weights', + disable=self.disable_tqdm) for mg_layer in prog_bar: layer_idx = mg_layer.layer_number - 1 # 1-indexed to 0-indexed try: self._load_layer(mg_layer, loader, layer_idx) except Exception as e: - print(f"Warning: Failed to load layer {layer_idx}: {e}") - + print(f'Warning: Failed to load layer {layer_idx}: {e}') + # Load post-process on last PP rank if self.pp_size <= 1 or self.pp_rank == self.pp_size - 1: try: self._load_final_layernorm(mg_model, loader) 
self._load_output_layer(mg_model, loader) except Exception as e: - print(f"Warning: Failed to load post-process: {e}") - - def _load_peft_weights(self, mg_model: nn.Module, loader: SafetensorLoader): + print(f'Warning: Failed to load post-process: {e}') + + def _load_peft_weights(self, mg_model: nn.Module, + loader: SafetensorLoader): """Load PEFT/LoRA adapter weights.""" state_dict = loader.get_state_dict() hf_prefix = 'base_model.model.' if self._is_peft_format else '' - + # Build mapping from HF keys to Megatron keys for key, lazy_tensor in state_dict.items(): # Remove base_model.model. prefix if key.startswith(hf_prefix): key = key[len(hf_prefix):] - + # Parse the key to find target module if '.lora_A.' in key or '.lora_B.' in key: tensor = lazy_tensor.load() self._load_peft_tensor(mg_model, key, tensor) - - def _load_peft_tensor(self, mg_model: nn.Module, key: str, tensor: torch.Tensor): + + def _load_peft_tensor(self, mg_model: nn.Module, key: str, + tensor: torch.Tensor): """Load a single PEFT tensor into the model.""" # Parse key: model.layers.0.self_attn.q_proj.lora_A.weight parts = key.split('.') - + # Find layer index layer_idx = None for i, p in enumerate(parts): if p == 'layers' and i + 1 < len(parts): layer_idx = int(parts[i + 1]) break - + if layer_idx is None: return - + # Get layer decoder = deep_getattr(mg_model, 'decoder') if decoder is None: decoder = mg_model - + layers = getattr(decoder, 'layers', []) for layer in layers: if layer.layer_number - 1 == layer_idx: @@ -1034,11 +1121,11 @@ def _load_peft_tensor(self, mg_model: nn.Module, key: str, tensor: torch.Tensor) break else: return - + # Determine target and lora type is_lora_A = '.lora_A.' in key is_lora_B = '.lora_B.' 
in key - + if 'self_attn' in key: mg_attn = mg_layer.self_attention if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key: @@ -1057,23 +1144,23 @@ def _load_peft_tensor(self, mg_model: nn.Module, key: str, tensor: torch.Tensor) return else: return - + if target is None: return - + # Get LoRA module if is_lora_A: lora_module = deep_getattr(target, f'lora_A.{self._adapter_name}') else: lora_module = deep_getattr(target, f'lora_B.{self._adapter_name}') - + if lora_module is not None and hasattr(lora_module, 'weight'): lora_module.weight.data.copy_(tensor) - + # ========================================================================= # Weight Saving Methods # ========================================================================= - + def export_weights( self, mg_models: Union[nn.Module, List[nn.Module]], @@ -1083,7 +1170,7 @@ def export_weights( tqdm_desc: str = 'Exporting: ', ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Export weights from Megatron model to HuggingFace format. - + Yields: Tuples of (key, tensor) for each weight. """ @@ -1093,52 +1180,58 @@ def export_weights( self._adapter_name = 'default' self._peft_target_modules = set() self._peft_modules_to_save = set() - + if not isinstance(mg_models, (list, tuple)): mg_models = [mg_models] - + hf_prefix = 'base_model.model.' 
if is_peft_format else '' - + with torch.no_grad(): # For now, handle single model mg_model = mg_models[0] - + decoder = deep_getattr(mg_model, 'decoder') if decoder is None: decoder = mg_model - + layers = getattr(decoder, 'layers', []) - + if not is_peft_format: # Export embedding if self.pp_size <= 1 or self.pp_rank == 0: - embed = deep_getattr(mg_model, 'embedding.word_embeddings.weight') + embed = deep_getattr(mg_model, + 'embedding.word_embeddings.weight') if embed is not None: - weight, _ = self._get_weight(embed.data, 'word_embeddings.weight') + weight, _ = self._get_weight(embed.data, + 'word_embeddings.weight') if weight is not None: weight = weight[:self.config.vocab_size] yield f'{hf_prefix}{self.HF_EMBED_KEY}', weight - + # Export layers prog_bar = tqdm(layers, desc=tqdm_desc, disable=self.disable_tqdm) for mg_layer in prog_bar: layer_idx = mg_layer.layer_number - 1 - yield from self._export_layer(mg_layer, layer_idx, hf_prefix, is_peft_format) - + yield from self._export_layer(mg_layer, layer_idx, hf_prefix, + is_peft_format) + if not is_peft_format: # Export final layernorm and output layer if self.pp_size <= 1 or self.pp_rank == self.pp_size - 1: - ln_module = deep_getattr(mg_model, 'decoder.final_layernorm') + ln_module = deep_getattr(mg_model, + 'decoder.final_layernorm') if ln_module is not None: - yield f'{hf_prefix}{self.HF_FINAL_LAYERNORM_KEY}', ln_module.weight.data.clone() - + yield f'{hf_prefix}{self.HF_FINAL_LAYERNORM_KEY}', ln_module.weight.data.clone( + ) + output = deep_getattr(mg_model, 'output_layer.weight') if output is not None: - weight, _ = self._get_weight(output.data, 'output_layer.weight') + weight, _ = self._get_weight(output.data, + 'output_layer.weight') if weight is not None: weight = weight[:self.config.vocab_size] yield f'{hf_prefix}{self.HF_LM_HEAD_KEY}', weight - + def _export_layer( self, mg_layer, @@ -1148,10 +1241,10 @@ def _export_layer( ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Export a single 
layer.""" prefix = f'{hf_prefix}{self.HF_LAYERS_PREFIX}.{layer_idx}.' - + mg_attn = mg_layer.self_attention mg_mlp = mg_layer.mlp - + num_heads = self.config.num_attention_heads num_kv_heads = self.config.num_key_value_heads hidden_size = self.config.hidden_size @@ -1159,44 +1252,62 @@ def _export_layer( heads_per_group = num_heads // num_kv_heads q_dim = heads_per_group * head_dim kv_dim = head_dim - + if not is_peft_format: # Export QKV - qkv_weight, _ = self._get_weight(mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') + qkv_weight, _ = self._get_weight(mg_attn.linear_qkv.weight.data, + 'linear_qkv.weight') if qkv_weight is not None: qkv_weight = qkv_weight.reshape(num_kv_heads, -1, hidden_size) - yield f'{prefix}self_attn.q_proj.weight', qkv_weight[:, :q_dim, :].reshape(-1, hidden_size).clone() - yield f'{prefix}self_attn.k_proj.weight', qkv_weight[:, q_dim:q_dim+kv_dim, :].reshape(-1, hidden_size).clone() - yield f'{prefix}self_attn.v_proj.weight', qkv_weight[:, -kv_dim:, :].reshape(-1, hidden_size).clone() - + yield f'{prefix}self_attn.q_proj.weight', qkv_weight[:, : + q_dim, :].reshape( + -1, + hidden_size + ).clone() + yield f'{prefix}self_attn.k_proj.weight', qkv_weight[:, q_dim: + q_dim + + kv_dim, :].reshape( + -1, + hidden_size + ).clone() + yield f'{prefix}self_attn.v_proj.weight', qkv_weight[:, + -kv_dim:, :].reshape( + -1, + hidden_size + ).clone() + # Export O - o_weight, _ = self._get_weight(mg_attn.linear_proj.weight.data, 'linear_proj.weight') + o_weight, _ = self._get_weight(mg_attn.linear_proj.weight.data, + 'linear_proj.weight') if o_weight is not None: yield f'{prefix}self_attn.o_proj.weight', o_weight - + # Export layernorms ln = deep_getattr(mg_attn, 'linear_qkv.layer_norm_weight') if ln is not None: yield f'{prefix}input_layernorm.weight', ln.data.clone() - + # Export MLP - fc1_weight, _ = self._get_weight(mg_mlp.linear_fc1.weight.data, 'linear_fc1.weight') + fc1_weight, _ = self._get_weight(mg_mlp.linear_fc1.weight.data, + 
'linear_fc1.weight') if fc1_weight is not None: fc1_weight = fc1_weight.view(2, -1, hidden_size) yield f'{prefix}mlp.gate_proj.weight', fc1_weight[0].clone() yield f'{prefix}mlp.up_proj.weight', fc1_weight[1].clone() - - fc2_weight, _ = self._get_weight(mg_mlp.linear_fc2.weight.data, 'linear_fc2.weight') + + fc2_weight, _ = self._get_weight(mg_mlp.linear_fc2.weight.data, + 'linear_fc2.weight') if fc2_weight is not None: yield f'{prefix}mlp.down_proj.weight', fc2_weight - + ln2 = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') if ln2 is not None: - yield f'{prefix}post_attention_layernorm.weight', ln2.data.clone() + yield f'{prefix}post_attention_layernorm.weight', ln2.data.clone( + ) else: # Export LoRA weights only yield from self._export_lora_layer(mg_attn, mg_mlp, prefix) - + def _export_lora_layer( self, mg_attn, @@ -1206,77 +1317,115 @@ def _export_lora_layer( """Export LoRA weights from a layer.""" # Check if LoRA is applied from twinkle.megatron.tuners import LoraParallelLinear - + # Attention LoRA if isinstance(mg_attn.linear_qkv, LoraParallelLinear): - lora_A = deep_getattr(mg_attn.linear_qkv, f'lora_A.{self._adapter_name}.weight') - lora_B = deep_getattr(mg_attn.linear_qkv, f'lora_B.{self._adapter_name}.weight') - + lora_A = deep_getattr(mg_attn.linear_qkv, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_attn.linear_qkv, + f'lora_B.{self._adapter_name}.weight') + if lora_A is not None and lora_B is not None: - lora_A, _ = self._get_weight(lora_A.data, 'linear_qkv.lora_A.weight') - lora_B, _ = self._get_weight(lora_B.data, 'linear_qkv.lora_B.weight') - + lora_A, _ = self._get_weight(lora_A.data, + 'linear_qkv.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_qkv.lora_B.weight') + if lora_A is not None: - self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) + self._peft_target_modules.update( + {'q_proj', 'k_proj', 'v_proj'}) # Split lora_B for Q, K, V for key in ['q_proj', 'k_proj', 'v_proj']: - 
yield f'{prefix}self_attn.{key}.lora_A.weight', lora_A.clone() - + yield f'{prefix}self_attn.{key}.lora_A.weight', lora_A.clone( + ) + num_kv_heads = self.config.num_key_value_heads head_dim = self.config.hidden_size // self.config.num_attention_heads heads_per_group = self.config.num_attention_heads // num_kv_heads q_dim = heads_per_group * head_dim - + lora_B = lora_B.reshape(num_kv_heads, -1, lora_B.shape[-1]) - yield f'{prefix}self_attn.q_proj.lora_B.weight', lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() - yield f'{prefix}self_attn.k_proj.lora_B.weight', lora_B[:, q_dim:-head_dim, :].reshape(-1, lora_B.shape[-1]).clone() - yield f'{prefix}self_attn.v_proj.lora_B.weight', lora_B[:, -head_dim:, :].reshape(-1, lora_B.shape[-1]).clone() - + yield f'{prefix}self_attn.q_proj.lora_B.weight', lora_B[:, :q_dim, :].reshape( + -1, lora_B.shape[-1]).clone() + yield f'{prefix}self_attn.k_proj.lora_B.weight', lora_B[:, + q_dim: + -head_dim, :].reshape( + -1, + lora_B + . + shape[ + -1] + ).clone( + ) + yield f'{prefix}self_attn.v_proj.lora_B.weight', lora_B[:, -head_dim:, :].reshape( + -1, lora_B.shape[-1]).clone() + # O projection LoRA if isinstance(mg_attn.linear_proj, LoraParallelLinear): - lora_A = deep_getattr(mg_attn.linear_proj, f'lora_A.{self._adapter_name}.weight') - lora_B = deep_getattr(mg_attn.linear_proj, f'lora_B.{self._adapter_name}.weight') - + lora_A = deep_getattr(mg_attn.linear_proj, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_attn.linear_proj, + f'lora_B.{self._adapter_name}.weight') + if lora_A is not None and lora_B is not None: - lora_A, _ = self._get_weight(lora_A.data, 'linear_proj.lora_A.weight') - lora_B, _ = self._get_weight(lora_B.data, 'linear_proj.lora_B.weight') - + lora_A, _ = self._get_weight(lora_A.data, + 'linear_proj.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_proj.lora_B.weight') + if lora_A is not None: self._peft_target_modules.add('o_proj') - yield 
f'{prefix}self_attn.o_proj.lora_A.weight', lora_A.clone() - yield f'{prefix}self_attn.o_proj.lora_B.weight', lora_B.clone() - + yield f'{prefix}self_attn.o_proj.lora_A.weight', lora_A.clone( + ) + yield f'{prefix}self_attn.o_proj.lora_B.weight', lora_B.clone( + ) + # MLP LoRA - if hasattr(mg_mlp, 'linear_fc1') and isinstance(mg_mlp.linear_fc1, LoraParallelLinear): - lora_A = deep_getattr(mg_mlp.linear_fc1, f'lora_A.{self._adapter_name}.weight') - lora_B = deep_getattr(mg_mlp.linear_fc1, f'lora_B.{self._adapter_name}.weight') - + if hasattr(mg_mlp, 'linear_fc1') and isinstance( + mg_mlp.linear_fc1, LoraParallelLinear): + lora_A = deep_getattr(mg_mlp.linear_fc1, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_mlp.linear_fc1, + f'lora_B.{self._adapter_name}.weight') + if lora_A is not None and lora_B is not None: - lora_A, _ = self._get_weight(lora_A.data, 'linear_fc1.lora_A.weight') - lora_B, _ = self._get_weight(lora_B.data, 'linear_fc1.lora_B.weight') - + lora_A, _ = self._get_weight(lora_A.data, + 'linear_fc1.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_fc1.lora_B.weight') + if lora_A is not None: self._peft_target_modules.update({'gate_proj', 'up_proj'}) for key in ['gate_proj', 'up_proj']: - yield f'{prefix}mlp.{key}.lora_A.weight', lora_A.clone() - + yield f'{prefix}mlp.{key}.lora_A.weight', lora_A.clone( + ) + lora_B = lora_B.reshape(2, -1, lora_B.shape[-1]) - yield f'{prefix}mlp.gate_proj.lora_B.weight', lora_B[0].clone() - yield f'{prefix}mlp.up_proj.lora_B.weight', lora_B[1].clone() - - if hasattr(mg_mlp, 'linear_fc2') and isinstance(mg_mlp.linear_fc2, LoraParallelLinear): - lora_A = deep_getattr(mg_mlp.linear_fc2, f'lora_A.{self._adapter_name}.weight') - lora_B = deep_getattr(mg_mlp.linear_fc2, f'lora_B.{self._adapter_name}.weight') - + yield f'{prefix}mlp.gate_proj.lora_B.weight', lora_B[ + 0].clone() + yield f'{prefix}mlp.up_proj.lora_B.weight', lora_B[ + 1].clone() + + if hasattr(mg_mlp, 'linear_fc2') and 
isinstance( + mg_mlp.linear_fc2, LoraParallelLinear): + lora_A = deep_getattr(mg_mlp.linear_fc2, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_mlp.linear_fc2, + f'lora_B.{self._adapter_name}.weight') + if lora_A is not None and lora_B is not None: - lora_A, _ = self._get_weight(lora_A.data, 'linear_fc2.lora_A.weight') - lora_B, _ = self._get_weight(lora_B.data, 'linear_fc2.lora_B.weight') - + lora_A, _ = self._get_weight(lora_A.data, + 'linear_fc2.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_fc2.lora_B.weight') + if lora_A is not None: self._peft_target_modules.add('down_proj') - yield f'{prefix}mlp.down_proj.lora_A.weight', lora_A.clone() - yield f'{prefix}mlp.down_proj.lora_B.weight', lora_B.clone() - + yield f'{prefix}mlp.down_proj.lora_A.weight', lora_A.clone( + ) + yield f'{prefix}mlp.down_proj.lora_B.weight', lora_B.clone( + ) + def save_weights( self, mg_models: Union[nn.Module, List[nn.Module]], @@ -1284,21 +1433,21 @@ def save_weights( is_peft_format: bool = False, ) -> None: """Save Megatron model weights in HuggingFace format. - + Args: mg_models: Megatron model(s) to save. output_dir: Directory to save weights. is_peft_format: Whether saving in PEFT format. - + Note: For DP > 1, only DP rank 0 writes to disk. All ranks participate in tensor gather operations for TP. 
""" torch.cuda.empty_cache() - + # Determine if this rank should write should_write = is_last_rank() - + # Only the writing rank creates the saver saver = None if should_write: @@ -1307,37 +1456,40 @@ def save_weights( max_shard_size=self.config.max_shard_size, is_peft_format=is_peft_format, ) - + # All ranks participate in export (needed for TP gather) for key, tensor in self.export_weights( - mg_models, - target_device='cpu', - only_last_rank=True, - is_peft_format=is_peft_format, - tqdm_desc='Saving: ', + mg_models, + target_device='cpu', + only_last_rank=True, + is_peft_format=is_peft_format, + tqdm_desc='Saving: ', ): if saver is not None and tensor is not None: saver.add_tensor(key, tensor) - + if saver is not None: saver.finalize() - + # Save config on writing rank only if should_write: if is_peft_format and not isinstance(mg_models, (list, tuple)): mg_models = [mg_models] - + if is_peft_format and hasattr(mg_models[0], 'peft_config'): - peft_config = copy(mg_models[0].peft_config.get(self._adapter_name)) + peft_config = copy(mg_models[0].peft_config.get( + self._adapter_name)) if peft_config is not None: - peft_config.target_modules = list(self._peft_target_modules) - peft_config.modules_to_save = list(self._peft_modules_to_save) + peft_config.target_modules = list( + self._peft_target_modules) + peft_config.modules_to_save = list( + self._peft_modules_to_save) peft_config.save_pretrained(output_dir) elif not is_peft_format and self.hf_config is not None: # Save HF config self.hf_config.vocab_size = self.config.padded_vocab_size self.hf_config.save_pretrained(output_dir) - + # Synchronize all ranks before continuing if dist.is_initialized(): dist.barrier() @@ -1345,10 +1497,9 @@ def save_weights( class TwinkleBridgeAdapter: """Adapter for weight loading using TwinkleGPTBridge. - + Provides a simple interface for loading HF weights into Megatron models. 
""" - def __init__( self, hf_config: Any, @@ -1363,7 +1514,7 @@ def __init__( """Initialize the bridge adapter.""" self.hf_config = hf_config self.model_path = model_path - + # Create bridge config self.config = BridgeConfig.from_hf_config( hf_config=hf_config, @@ -1374,9 +1525,9 @@ def __init__( ) if etp_size is not None: self.config.etp_size = etp_size - + self._bridge = None - + def _get_bridge(self) -> TwinkleGPTBridge: """Get or create the bridge instance.""" if self._bridge is None: @@ -1385,7 +1536,7 @@ def _get_bridge(self) -> TwinkleGPTBridge: hf_config=self.hf_config, ) return self._bridge - + def load_weights( self, mg_model: nn.Module, @@ -1396,11 +1547,11 @@ def load_weights( """Load HuggingFace weights into Megatron model.""" model_path = model_path or self.model_path if model_path is None: - raise ValueError("model_path must be provided") - + raise ValueError('model_path must be provided') + bridge = self._get_bridge() bridge.load_weights(mg_model, model_path, is_peft_format, adapter_name) - + def save_weights( self, mg_models: Union[nn.Module, List[nn.Module]], @@ -1415,12 +1566,12 @@ def save_weights( class TwinkleBridgeInitializer: """ Megatron model initializer. - + This class provides complete model initialization flow including: - Megatron parallel state initialization - Model creation from HuggingFace config - Weight loading using TwinkleGPTBridge - + Example: initializer = TwinkleBridgeInitializer( tp_size=2, @@ -1429,7 +1580,6 @@ class TwinkleBridgeInitializer: ) model = initializer.create_model('Qwen/Qwen2.5-7B-Instruct') """ - def __init__( self, tp_size: int = 1, @@ -1447,7 +1597,7 @@ def __init__( recompute_num_layers: Optional[int] = None, ): """Initialize TwinkleBridgeInitializer. - + Args: tp_size: Tensor parallel size. pp_size: Pipeline parallel size. 
@@ -1482,34 +1632,34 @@ def __init__( self.recompute_modules = recompute_modules or ['core_attn'] self.recompute_method = recompute_method self.recompute_num_layers = recompute_num_layers - + self._model = None self._bridge = None self._hf_config = None self._model_path = None - + def _download_model(self, model_path: str) -> str: """Download model if it's a model ID.""" if os.path.isdir(model_path): return model_path - + try: from modelscope import snapshot_download return snapshot_download(model_path) except ImportError: from huggingface_hub import snapshot_download return snapshot_download(model_path) - + def _initialize_megatron(self, hf_config: Any = None): """Initialize Megatron parallel state. - + This sets up the required process groups for tensor, pipeline, and data parallelism using Megatron's parallel state module directly. - + Handles both local (torchrun) and Ray execution modes: - Local: Uses torchrun's environment variables (already set) - Ray: Uses RayHelper's environment variables (RANK, WORLD_SIZE, etc.) - + Args: hf_config: Optional HuggingFace config for additional model parameters. 
""" @@ -1518,17 +1668,17 @@ def _initialize_megatron(self, hf_config: Any = None): from datetime import timedelta from megatron.core import parallel_state as mpu from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed - + # Check if already initialized try: if mpu.is_initialized(): return except AssertionError: pass - + # Determine execution mode twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') - + # Initialize distributed if not already if not dist.is_initialized(): if twinkle_mode == 'ray': @@ -1538,10 +1688,10 @@ def _initialize_megatron(self, hf_config: Any = None): master_addr = os.environ.get('MASTER_ADDR', 'localhost') master_port = os.environ.get('MASTER_PORT', '29500') local_rank = int(os.environ.get('LOCAL_RANK', '0')) - + # Set CUDA device before init_process_group torch.cuda.set_device(local_rank) - + # Initialize process group with explicit parameters dist.init_process_group( backend='nccl', @@ -1553,7 +1703,7 @@ def _initialize_megatron(self, hf_config: Any = None): else: # Local mode (torchrun): environment variables are already set dist.init_process_group(backend='nccl') - + # Initialize Megatron parallel state directly mpu.initialize_model_parallel( tensor_model_parallel_size=self.tp_size, @@ -1561,22 +1711,22 @@ def _initialize_megatron(self, hf_config: Any = None): context_parallel_size=self.cp_size, expert_model_parallel_size=self.ep_size, ) - + # Initialize CUDA RNG tracker for tensor parallel random states # This is required when use_cpu_initialization=False (GPU initialization) model_parallel_cuda_manual_seed(42) - + def _create_model_from_config( - self, - hf_config: Any, + self, + hf_config: Any, padded_vocab_size: int, ) -> nn.Module: """Create Megatron GPT model from HuggingFace config. - + Args: hf_config: HuggingFace model configuration. padded_vocab_size: Padded vocabulary size. - + Returns: Megatron GPT model. 
""" @@ -1586,22 +1736,22 @@ def _create_model_from_config( from megatron.core.transformer.enums import AttnBackend from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec, - ) - + get_gpt_layer_with_transformer_engine_spec, ) + # Convert HF config to Megatron config from ..utils import convert_hf_config mg_config_dict = convert_hf_config(hf_config) - + # Build TransformerConfig num_attention_heads = mg_config_dict['num_attention_heads'] - num_query_groups = mg_config_dict.get('num_query_groups', num_attention_heads) + num_query_groups = mg_config_dict.get('num_query_groups', + num_attention_heads) num_layers = mg_config_dict['num_layers'] - + # Configure activation recomputation recompute_method = self.recompute_method recompute_num_layers = self.recompute_num_layers - + # Auto-configure for 'full' recomputation if not specified if self.recompute_granularity == 'full': if recompute_method is None: @@ -1609,93 +1759,116 @@ def _create_model_from_config( if recompute_num_layers is None: # Recompute all layers for maximum memory savings recompute_num_layers = num_layers // self.pp_size - + # Create finalize_model_grads function for DP gradient synchronization # Megatron's native finalize_model_grads requires DDP-wrapped models with ddp_config. # For PEFT/LoRA models, we use a custom implementation that handles non-DDP models. from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads - - def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): + + def finalize_model_grads_for_lora(model, + num_tokens=None, + pg_collection=None): """Finalize model grads that handles both DDP and PEFT/LoRA models. 
- + For DDP-wrapped models: Delegates to Megatron's native finalize_model_grads For PEFT/LoRA models: Manually all-reduce gradients across DP ranks - + This is necessary because PEFT models don't have ddp_config attribute that Megatron's native implementation expects. """ from megatron.core import parallel_state as mpu - + # Check if model is DDP-wrapped (has ddp_config) if hasattr(model[0], 'ddp_config'): # Use native implementation for DDP models - return _native_finalize_model_grads(model, num_tokens, pg_collection) - + return _native_finalize_model_grads(model, num_tokens, + pg_collection) + # For PEFT/LoRA models, call finish_grad_sync on each chunk # The model should have finish_grad_sync added by MegatronModel.add_adapter_to_model for model_chunk in model: if hasattr(model_chunk, 'finish_grad_sync'): model_chunk.finish_grad_sync() - + # MoE configuration num_experts = mg_config_dict.get('num_experts', 0) or 0 moe_ffn_hidden_size = mg_config_dict.get('moe_ffn_hidden_size') moe_router_topk = mg_config_dict.get('moe_router_topk', 2) or 2 - moe_shared_expert_intermediate_size = mg_config_dict.get('moe_shared_expert_intermediate_size') - + moe_shared_expert_intermediate_size = mg_config_dict.get( + 'moe_shared_expert_intermediate_size') + # Build MoE-related kwargs moe_kwargs = {} if num_experts > 0: moe_kwargs.update({ - 'num_moe_experts': num_experts, - 'moe_router_topk': moe_router_topk, - 'moe_router_load_balancing_type': mg_config_dict.get('moe_router_load_balancing_type', 'aux_loss'), - # MoE performance optimizations (aligned with Swift defaults) - 'moe_token_dispatcher_type': mg_config_dict.get('moe_token_dispatcher_type', 'alltoall'), # 'alltoall' is more efficient than 'allgather' - 'moe_grouped_gemm': mg_config_dict.get('moe_grouped_gemm', True), # Enable for better performance (requires grouped_gemm package) - 'moe_aux_loss_coeff': mg_config_dict.get('moe_aux_loss_coeff', 0.0), # Auxiliary load balancing loss coefficient + 'num_moe_experts': + 
num_experts, + 'moe_router_topk': + moe_router_topk, + 'moe_router_load_balancing_type': + mg_config_dict.get('moe_router_load_balancing_type', + 'aux_loss'), + # MoE performance optimizations + 'moe_token_dispatcher_type': + mg_config_dict.get( + 'moe_token_dispatcher_type', 'alltoall' + ), # 'alltoall' is more efficient than 'allgather' + 'moe_grouped_gemm': + mg_config_dict.get( + 'moe_grouped_gemm', True + ), # Enable for better performance (requires grouped_gemm package) + 'moe_aux_loss_coeff': + mg_config_dict.get( + 'moe_aux_loss_coeff', + 0.0), # Auxiliary load balancing loss coefficient }) - + # FFN hidden size for MoE if moe_ffn_hidden_size: moe_kwargs['moe_ffn_hidden_size'] = moe_ffn_hidden_size - + # Shared expert configuration if moe_shared_expert_intermediate_size: - moe_kwargs['moe_shared_expert_intermediate_size'] = moe_shared_expert_intermediate_size - + moe_kwargs[ + 'moe_shared_expert_intermediate_size'] = moe_shared_expert_intermediate_size + # Router score function (sigmoid for Qwen3, softmax for others) if mg_config_dict.get('moe_router_score_function'): - moe_kwargs['moe_router_score_function'] = mg_config_dict['moe_router_score_function'] - + moe_kwargs['moe_router_score_function'] = mg_config_dict[ + 'moe_router_score_function'] + # Expert bias for sigmoid router if mg_config_dict.get('moe_router_enable_expert_bias'): - moe_kwargs['moe_router_enable_expert_bias'] = mg_config_dict['moe_router_enable_expert_bias'] - + moe_kwargs['moe_router_enable_expert_bias'] = mg_config_dict[ + 'moe_router_enable_expert_bias'] + # Sequence parallel requires TP > 1 # Auto-enable for MoE with TP > 1 (required by Megatron) use_sequence_parallel = self.sequence_parallel and self.tp_size > 1 if num_experts > 0 and self.tp_size > 1 and not use_sequence_parallel: use_sequence_parallel = True - print(f"Auto-enabling sequence_parallel for MoE with TP={self.tp_size}") - + print( + f'Auto-enabling sequence_parallel for MoE with TP={self.tp_size}' + ) + # For MoE 
models, ffn_hidden_size should be moe_ffn_hidden_size if not specified ffn_hidden_size = mg_config_dict.get('ffn_hidden_size') if ffn_hidden_size is None: - ffn_hidden_size = moe_ffn_hidden_size or (4 * mg_config_dict['hidden_size']) - + ffn_hidden_size = moe_ffn_hidden_size or ( + 4 * mg_config_dict['hidden_size']) + # For models with non-standard head dimensions (like Qwen3-30B-A3B) kv_channels = mg_config_dict.get('kv_channels') - + # Activation function for SwiGLU (required by Megatron when gated_linear_unit=True) use_swiglu = mg_config_dict.get('swiglu', True) activation_func = torch.nn.functional.silu if use_swiglu else torch.nn.functional.gelu - - # Enable bias_activation_fusion for SwiGLU (same as Swift) + + # Enable bias_activation_fusion for SwiGLU # Note: Only works with TransformerEngine and no bias in linear layers has_bias = not mg_config_dict.get('disable_bias_linear', True) bias_activation_fusion = use_swiglu and not has_bias - + config = TransformerConfig( num_layers=num_layers, hidden_size=mg_config_dict['hidden_size'], @@ -1709,13 +1882,16 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): expert_model_parallel_size=self.ep_size, sequence_parallel=use_sequence_parallel, params_dtype=self.params_dtype, - pipeline_dtype=self.params_dtype, # Required when using pipeline parallelism + pipeline_dtype=self. 
+ params_dtype, # Required when using pipeline parallelism use_cpu_initialization=self.use_cpu_initialization, add_qkv_bias=mg_config_dict.get('add_qkv_bias', False), - add_bias_linear=not mg_config_dict.get('disable_bias_linear', True), + add_bias_linear=not mg_config_dict.get('disable_bias_linear', + True), gated_linear_unit=use_swiglu, activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise - bias_activation_fusion=bias_activation_fusion, # Fused SwiGLU for performance + bias_activation_fusion= + bias_activation_fusion, # Fused SwiGLU for performance normalization='RMSNorm', layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6), qk_layernorm=mg_config_dict.get('qk_layernorm', False), @@ -1729,7 +1905,8 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): attention_backend=AttnBackend.flash, # FlashAttention for speed # Activation recomputation for memory efficiency recompute_granularity=self.recompute_granularity, - recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, + recompute_modules=self.recompute_modules + if self.recompute_granularity == 'selective' else None, recompute_method=recompute_method, recompute_num_layers=recompute_num_layers, # Critical: Set finalize_model_grads_func for DP gradient synchronization @@ -1738,10 +1915,10 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): # MoE configuration **moe_kwargs, ) - + # Save transformer config for later use (e.g., DDP wrapping) self._transformer_config = config - + # Get layer spec - enable moe_grouped_gemm for MoE models moe_grouped_gemm = num_experts > 0 try: @@ -1757,11 +1934,11 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=mg_config_dict.get('qk_layernorm', False), ) - + # Create model max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) rotary_base = mg_config_dict.get('rotary_base', 
10000) - + model = GPTModel( config=config, transformer_layer_spec=layer_spec, @@ -1770,50 +1947,53 @@ def finalize_model_grads_for_lora(model, num_tokens=None, pg_collection=None): pre_process=mpu.is_pipeline_first_stage(), post_process=mpu.is_pipeline_last_stage(), parallel_output=True, - share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), + share_embeddings_and_output_weights=getattr( + hf_config, 'tie_word_embeddings', False), position_embedding_type='rope', rotary_base=rotary_base, ) - + return model - + def _pad_vocab_size(self, vocab_size: int) -> int: """Pad vocab size for tensor parallelism.""" divisor = self.tp_size * 128 return ((vocab_size + divisor - 1) // divisor) * divisor - + def create_model( self, model_path: str, load_weights: bool = True, ) -> nn.Module: """Create Megatron model from HuggingFace checkpoint. - + Args: model_path: Path to HuggingFace model or model ID. load_weights: Whether to load weights. - + Returns: Megatron model. 
""" from transformers import AutoConfig - + # Download model if needed model_path = HubOperation.download_model(model_path) self._model_path = model_path - + # Load HF config first (needed for initialization) - self._hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - + self._hf_config = AutoConfig.from_pretrained(model_path, + trust_remote_code=True) + # Initialize Megatron parallel state with hf_config for proper args setup self._initialize_megatron(self._hf_config) - + # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(self._hf_config.vocab_size) - + # Create model - self._model = self._create_model_from_config(self._hf_config, padded_vocab_size) - + self._model = self._create_model_from_config(self._hf_config, + padded_vocab_size) + # Load weights if load_weights: bridge_adapter = TwinkleBridgeAdapter( @@ -1827,35 +2007,35 @@ def create_model( ) bridge_adapter.load_weights(self._model, model_path) self._bridge = bridge_adapter._get_bridge() - + # Synchronize all ranks after model creation and weight loading # This is critical for Pipeline Parallel to ensure all ranks are ready # before any collective communication operations if dist.is_initialized(): dist.barrier() - + return self._model - + @property def hf_config(self): """Get the HuggingFace config.""" return self._hf_config - + @property def bridge(self): """Get the bridge instance.""" return self._bridge - + def load_weights(self, model: nn.Module, model_path: str): """Load weights into an existing model. - + Args: model: Megatron model. model_path: Path to HuggingFace checkpoint. 
""" if self._bridge is None and self._hf_config is None: - raise ValueError("Must call create_model first") - + raise ValueError('Must call create_model first') + padded_vocab_size = self._pad_vocab_size(self._hf_config.vocab_size) bridge_adapter = TwinkleBridgeAdapter( hf_config=self._hf_config, @@ -1866,22 +2046,27 @@ def load_weights(self, model: nn.Module, model_path: str): padded_vocab_size=padded_vocab_size, ) bridge_adapter.load_weights(model, model_path) - - def save_weights(self, models: Union[nn.Module, List[nn.Module]], output_dir: str, is_peft_format: bool = False): + + def save_weights(self, + models: Union[nn.Module, List[nn.Module]], + output_dir: str, + is_peft_format: bool = False): """Save weights in HuggingFace format. - + Args: models: Megatron model(s). output_dir: Output directory. is_peft_format: Whether to save in PEFT format. """ if self._bridge is None: - raise ValueError("Must load weights first") - + raise ValueError('Must load weights first') + if not isinstance(models, (list, tuple)): models = [models] - - self._bridge.save_weights(models, output_dir, is_peft_format=is_peft_format) + + self._bridge.save_weights(models, + output_dir, + is_peft_format=is_peft_format) # Legacy functions for backward compatibility @@ -1903,9 +2088,11 @@ def restore_megatron_args() -> None: def mock_megatron_args(args: SimpleNamespace): """Legacy function - no longer needed with TwinkleGPTBridge.""" from contextlib import contextmanager + @contextmanager def noop(): yield args + return noop() diff --git a/src/twinkle/megatron/model/initializer.py b/src/twinkle/megatron/model/initializer.py index ec3e9e1c..76b7b976 100644 --- a/src/twinkle/megatron/model/initializer.py +++ b/src/twinkle/megatron/model/initializer.py @@ -4,24 +4,25 @@ from typing import Any, Dict, Optional, Type import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn +from packaging import version # Direct imports - assume megatron is installed import 
megatron.core from megatron.core import parallel_state as mpu -from megatron.core.transformer import TransformerConfig from megatron.core.models.gpt import GPTModel -from packaging import version - -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +from megatron.core.transformer import TransformerConfig from ..utils import convert_hf_config +mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') + def _get_transformer_config_fields() -> set: """Get valid field names for TransformerConfig. - + Returns: Set of valid field names. """ @@ -30,13 +31,12 @@ def _get_transformer_config_fields() -> set: class MegatronModelInitializer: """Initialize Megatron-Core models from HuggingFace checkpoints. - + This class handles: - Converting HuggingFace config to Megatron TransformerConfig - Creating Megatron model architecture - Loading HuggingFace weights into Megatron model """ - def __init__( self, tp_size: int = 1, @@ -50,7 +50,7 @@ def __init__( use_cpu_initialization: bool = True, ): """Initialize MegatronModelInitializer. - + Args: tp_size: Tensor parallel size. pp_size: Pipeline parallel size. @@ -71,30 +71,30 @@ def __init__( self.sequence_parallel = sequence_parallel self.params_dtype = params_dtype self.use_cpu_initialization = use_cpu_initialization - + # Cache valid TransformerConfig fields self._valid_config_fields = _get_transformer_config_fields() - + def create_transformer_config( self, hf_config: Any, **overrides, ) -> 'TransformerConfig': """Create Megatron TransformerConfig from HuggingFace config. - + Args: hf_config: HuggingFace model config. **overrides: Config overrides. - + Returns: Megatron TransformerConfig. 
""" # Convert HuggingFace config to dict mg_config_dict = convert_hf_config(hf_config) - + # Apply overrides mg_config_dict.update(overrides) - + # Build config kwargs with only valid fields config_kwargs = { # Required fields @@ -110,36 +110,50 @@ def create_transformer_config( 'params_dtype': self.params_dtype, 'use_cpu_initialization': self.use_cpu_initialization, } - + # Optional fields - only add if valid for this Megatron version optional_fields = { - 'num_query_groups': mg_config_dict.get('num_query_groups', mg_config_dict['num_attention_heads']), - 'ffn_hidden_size': mg_config_dict.get('ffn_hidden_size', 4 * mg_config_dict['hidden_size']), - 'num_moe_experts': mg_config_dict.get('num_experts'), - 'moe_router_topk': mg_config_dict.get('moe_router_topk', 2) if mg_config_dict.get('num_experts') else None, - 'layernorm_epsilon': mg_config_dict.get('norm_epsilon', 1e-6), - 'add_qkv_bias': mg_config_dict.get('add_qkv_bias', False), - 'add_bias_linear': not mg_config_dict.get('disable_bias_linear', True), - 'gated_linear_unit': mg_config_dict.get('swiglu', True), - 'qk_layernorm': mg_config_dict.get('qk_layernorm', False), - 'normalization': 'RMSNorm', + 'num_query_groups': + mg_config_dict.get('num_query_groups', + mg_config_dict['num_attention_heads']), + 'ffn_hidden_size': + mg_config_dict.get('ffn_hidden_size', + 4 * mg_config_dict['hidden_size']), + 'num_moe_experts': + mg_config_dict.get('num_experts'), + 'moe_router_topk': + mg_config_dict.get('moe_router_topk', 2) + if mg_config_dict.get('num_experts') else None, + 'layernorm_epsilon': + mg_config_dict.get('norm_epsilon', 1e-6), + 'add_qkv_bias': + mg_config_dict.get('add_qkv_bias', False), + 'add_bias_linear': + not mg_config_dict.get('disable_bias_linear', True), + 'gated_linear_unit': + mg_config_dict.get('swiglu', True), + 'qk_layernorm': + mg_config_dict.get('qk_layernorm', False), + 'normalization': + 'RMSNorm', } - + # Add optional fields that are valid for this Megatron version for key, value in 
optional_fields.items(): if key in self._valid_config_fields and value is not None: config_kwargs[key] = value - + # Store rotary settings for GPTModel (not TransformerConfig) self._rotary_base = mg_config_dict.get('rotary_base', 10000) self._rotary_percent = mg_config_dict.get('rotary_percent', 1.0) - self._position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope') - + self._position_embedding_type = mg_config_dict.get( + 'position_embedding_type', 'rope') + # Create TransformerConfig config = TransformerConfig(**config_kwargs) - + return config - + def create_gpt_model( self, hf_config: Any, @@ -148,33 +162,34 @@ def create_gpt_model( **config_overrides, ) -> 'GPTModel': """Create Megatron GPT model from HuggingFace config. - + Args: hf_config: HuggingFace model config. vocab_size: Override vocab size. max_sequence_length: Override max sequence length. **config_overrides: Config overrides. - + Returns: Megatron GPTModel. """ # Create config (also sets self._rotary_base, etc.) config = self.create_transformer_config(hf_config, **config_overrides) - + # Get vocab size if vocab_size is None: vocab_size = hf_config.vocab_size - + # Pad vocab size for tensor parallelism padded_vocab_size = self._pad_vocab_size(vocab_size) - + # Get max sequence length if max_sequence_length is None: - max_sequence_length = getattr(hf_config, 'max_position_embeddings', 4096) - + max_sequence_length = getattr(hf_config, 'max_position_embeddings', + 4096) + # Get tie_word_embeddings setting tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False) - + # Create model with rotary settings passed directly to GPTModel model = GPTModel( config=config, @@ -189,28 +204,28 @@ def create_gpt_model( rotary_percent=self._rotary_percent, rotary_base=self._rotary_base, ) - + return model - + def _pad_vocab_size(self, vocab_size: int) -> int: """Pad vocab size for tensor parallelism. - + Args: vocab_size: Original vocab size. - + Returns: Padded vocab size. 
""" # Pad to multiple of tp_size * 128 for efficient parallelism divisor = self.tp_size * 128 return ((vocab_size + divisor - 1) // divisor) * divisor - + def _get_layer_spec(self, config: 'TransformerConfig'): """Get transformer layer specification. - + Args: config: Transformer config. - + Returns: Layer specification (ModuleSpec or TransformerBlockSubmodules). """ @@ -218,13 +233,14 @@ def _get_layer_spec(self, config: 'TransformerConfig'): get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_local_spec, ) - + # Determine if this is a MoE model num_experts = getattr(config, 'num_moe_experts', None) moe_grouped_gemm = getattr(config, 'moe_grouped_gemm', False) qk_layernorm = getattr(config, 'qk_layernorm', False) - multi_latent_attention = getattr(config, 'multi_latent_attention', False) - + multi_latent_attention = getattr(config, 'multi_latent_attention', + False) + # Try TE (TransformerEngine) layers first for better performance try: return get_gpt_layer_with_transformer_engine_spec( @@ -241,7 +257,7 @@ def _get_layer_spec(self, config: 'TransformerConfig'): qk_layernorm=qk_layernorm, multi_latent_attention=multi_latent_attention, ) - + def load_from_hf( self, model: nn.Module, @@ -249,14 +265,14 @@ def load_from_hf( hf_config: Any, ) -> None: """Load HuggingFace checkpoint into Megatron model. - + Args: model: The Megatron model. hf_model_path: Path to HuggingFace checkpoint or model ID. hf_config: HuggingFace model config. 
""" import os - + # Resolve model path if it's a model ID (not a local path) if not os.path.isdir(hf_model_path): from twinkle.hub import HubOperation @@ -264,7 +280,7 @@ def load_from_hf( # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(hf_config.vocab_size) - + # Use TwinkleBridgeAdapter from .bridge import TwinkleBridgeAdapter adapter = TwinkleBridgeAdapter( @@ -276,7 +292,7 @@ def load_from_hf( padded_vocab_size=padded_vocab_size, ) adapter.load_weights(model, hf_model_path) - + def initialize_megatron_model( hf_model_path: str, @@ -288,7 +304,7 @@ def initialize_megatron_model( load_weights: bool = True, ) -> nn.Module: """Convenience function to initialize Megatron model from HuggingFace checkpoint. - + Args: hf_model_path: Path to HuggingFace checkpoint. tp_size: Tensor parallel size. @@ -297,15 +313,15 @@ def initialize_megatron_model( ep_size: Expert parallel size. params_dtype: Parameter data type. load_weights: Whether to load weights. - + Returns: Initialized Megatron model. """ from transformers import AutoConfig - + # Load HuggingFace config hf_config = AutoConfig.from_pretrained(hf_model_path) - + # Create initializer initializer = MegatronModelInitializer( tp_size=tp_size, @@ -314,13 +330,12 @@ def initialize_megatron_model( ep_size=ep_size, params_dtype=params_dtype, ) - + # Create model model = initializer.create_gpt_model(hf_config) - + # Load weights if load_weights: initializer.load_from_hf(model, hf_model_path, hf_config) - - return model + return model diff --git a/src/twinkle/megatron/model/multi_tenant_megatron.py b/src/twinkle/megatron/model/multi_tenant_megatron.py index e280151f..f56b37d2 100644 --- a/src/twinkle/megatron/model/multi_tenant_megatron.py +++ b/src/twinkle/megatron/model/multi_tenant_megatron.py @@ -2,32 +2,29 @@ """ Multi-Tenant Megatron Model for LoRA training. 
-This module provides multi-tenant LoRA training support for Megatron models, -similar to MultiLoraTransformersModel but optimized for Megatron's architecture. - -Key features: -1. Uses MultiAdapter's ContextVar mechanism for tenant isolation -2. Integrates with Megatron's parallel state and DDP -3. Supports per-tenant optimizers, schedulers, and gradient accumulation -4. Compatible with Swift Megatron's LoraParallelLinear +This module integrates TenantManager and MultiTenantLoRADDP to provide +a complete multi-tenant training solution. """ import contextvars import logging import re from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Optional, Type import torch import torch.distributed as dist import torch.nn as nn +from ..distributed.multi_tenant_ddp import MultiTenantLoRADDP +from ..distributed.tenant_context import (get_current_tenant, require_tenant, + set_current_tenant, tenant_scope) +from ..distributed.tenant_manager import TenantManager, TenantState + logger = logging.getLogger(__name__) try: from megatron.core import parallel_state as mpu - from megatron.core.distributed import DistributedDataParallel as MegatronDDP from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig from megatron.core.transformer.transformer_config import TransformerConfig MEGATRON_AVAILABLE = True @@ -35,7 +32,6 @@ MEGATRON_AVAILABLE = False try: - from peft import LoraConfig, PeftModel from peft.tuners.lora import LoraLayer, LoraModel PEFT_AVAILABLE = True except ImportError: @@ -44,474 +40,294 @@ class MegatronMultiAdapter: """ - Megatron-compatible MultiAdapter using ContextVar for tenant isolation. - - This patches LoraLayer/LoraModel to use ContextVar-based adapter selection, - enabling thread/coroutine-safe multi-tenant training. 
- - Key difference from twinkle's MultiAdapter: - - Also patches Swift Megatron's LoraParallelLinear if present + Patches LoRA layers to use ContextVar-based adapter selection. + + This enables thread-safe multi-tenant training where each tenant's + active adapter is determined by the current context. """ - - _adapter_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( - 'megatron_adapter_name', default=None - ) + + _adapter_var: contextvars.ContextVar[ + Optional[str]] = contextvars.ContextVar('adapter_names', default=None) _patched: bool = False - + def __call__(self, module: nn.Module) -> nn.Module: - """ - Patch LoRA layers to use ContextVar-based adapter selection. - - Args: - module: Model containing LoRA layers. - - Returns: - Patched model (same instance, modified in-place). - """ + """Patch LoRA layers.""" if MegatronMultiAdapter._patched: return module - + self._patch_peft_lora() - self._patch_megatron_lora() - + self._patch_twinkle_lora() + module.set_current_adapter_name = MegatronMultiAdapter.set_current_adapter_name MegatronMultiAdapter._patched = True - + return module - + def _patch_peft_lora(self): - """Patch PEFT's LoraLayer and LoraModel.""" + """Patch PEFT's LoraLayer/LoraModel.""" if not PEFT_AVAILABLE: return - - def get_active_adapter(*args, **kwargs): + + if getattr(LoraLayer, '_patched', False): + return + + def get_active_adapter(*args): return MegatronMultiAdapter._adapter_var.get() - - def get_active_adapters(*args, **kwargs): - adapter_name = MegatronMultiAdapter._adapter_var.get() - return [adapter_name] if adapter_name else [] - - def set_active_adapters(_, value): - pass # Controlled via ContextVar - - def set_adapter(self, adapter_names): - pass # Controlled via ContextVar - - def mark_only_adapters_trainable(self, model) -> None: - for n, p in model.named_parameters(): - p.requires_grad = "lora_" in n - - # Patch LoraLayer - LoraLayer.active_adapter = property(get_active_adapter, set_active_adapters) - 
LoraLayer.active_adapters = property(get_active_adapters, set_active_adapters) - LoraLayer.set_adapter = set_adapter - - # Patch LoraModel - LoraModel.active_adapter = property(get_active_adapter, set_active_adapters) - LoraModel.active_adapters = property(get_active_adapters, set_active_adapters) - LoraModel.set_adapter = set_adapter - LoraModel._mark_only_adapters_as_trainable = mark_only_adapters_trainable - - logger.info("Patched PEFT LoraLayer/LoraModel for multi-tenant support") - - def _patch_megatron_lora(self): - """Patch Swift Megatron's LoraParallelLinear if available.""" + + def get_active_adapters(*args): + adapter = MegatronMultiAdapter._adapter_var.get() + return [adapter] if adapter else [] + + LoraLayer.active_adapter = property(get_active_adapter) + LoraLayer.active_adapters = property(get_active_adapters) + LoraLayer.set_adapter = lambda self, x: None + LoraLayer._patched = True + + LoraModel.active_adapter = property(get_active_adapter) + LoraModel.active_adapters = property(get_active_adapters) + LoraModel.set_adapter = lambda self, x: None + LoraModel._patched = True + + logger.info('Patched PEFT LoraLayer/LoraModel') + + def _patch_twinkle_lora(self): + """Patch Twinkle's LoraParallelLinear.""" try: - from swift.megatron.tuners.lora import LoraParallelLinear - + from twinkle.megatron.tuners.lora import LoraParallelLinear + if hasattr(LoraParallelLinear, '_patched'): + return + def get_active_adapter(self): return MegatronMultiAdapter._adapter_var.get() - + def get_active_adapters(self): - adapter_name = MegatronMultiAdapter._adapter_var.get() - return [adapter_name] if adapter_name else [] - - # Patch as properties - if not hasattr(LoraParallelLinear, '_multi_tenant_patched'): - LoraParallelLinear.active_adapter = property(get_active_adapter) - LoraParallelLinear.active_adapters = property(get_active_adapters) - LoraParallelLinear._multi_tenant_patched = True - logger.info("Patched LoraParallelLinear for multi-tenant support") + adapter = 
MegatronMultiAdapter._adapter_var.get() + return [adapter] if adapter else [] + + LoraParallelLinear.active_adapter = property(get_active_adapter) + LoraParallelLinear.active_adapters = property(get_active_adapters) + LoraParallelLinear._patched = True + logger.info('Patched LoraParallelLinear') except ImportError: - logger.debug("Swift Megatron LoraParallelLinear not available") - + pass + @staticmethod - def set_current_adapter_name(adapter_name: Optional[str]): - """Set the current adapter for this context.""" - MegatronMultiAdapter._adapter_var.set(adapter_name) - + def set_current_adapter_name(name: Optional[str]): + """Set current adapter.""" + MegatronMultiAdapter._adapter_var.set(name) + @staticmethod def get_current_adapter_name() -> Optional[str]: - """Get the current adapter name.""" + """Get current adapter.""" return MegatronMultiAdapter._adapter_var.get() -@dataclass -class TenantState: - """State for a single tenant.""" - adapter_name: str - process_group: Optional[dist.ProcessGroup] = None - optimizer: Optional[torch.optim.Optimizer] = None - scheduler: Optional[Any] = None - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None - lora_config: Optional['LoraConfig'] = None - - # Tracking - trainable_params: List[nn.Parameter] = field(default_factory=list) - param_names: Dict[nn.Parameter, str] = field(default_factory=dict) +class MultiTenantMegatronModel(nn.Module): + """ + Multi-tenant Megatron model wrapper. + Combines: + - TenantManager: Tenant lifecycle (adapters, optimizers) + - MultiTenantLoRADDP: Per-tenant gradient sync + - MegatronMultiAdapter: Context-based adapter selection -class MultiTenantMegatronModel: - """ - Multi-Tenant Megatron Model wrapper for LoRA training. - - This class provides: - 1. Multi-tenant adapter management using ContextVar - 2. Per-tenant optimizer and scheduler - 3. Gradient synchronization with tenant-specific process groups - 4. 
Integration with Megatron's DDP - - Design: - - Uses a single Megatron DDP wrapper for all tenants - - Each tenant has isolated LoRA adapters - - ContextVar ensures thread-safe adapter switching - Example: - >>> model = create_megatron_model(...) - >>> multi_tenant = MultiTenantMegatronModel(model, config, ddp_config) - >>> - >>> # Add tenants - >>> multi_tenant.add_tenant('user_a', lora_config_a) - >>> multi_tenant.add_tenant('user_b', lora_config_b) - >>> - >>> # Training - >>> with multi_tenant.tenant_context('user_a'): - ... output = multi_tenant(input) - ... loss.backward() - ... multi_tenant.step() + >>> model = MultiTenantMegatronModel(base_model, config) + >>> + >>> # Initialize tenant (creates adapter, buffers, optimizer) + >>> tenant_id = model.initialize(lora_config=LoraConfig(r=8)) + >>> + >>> # Training (uses current tenant automatically) + >>> model.zero_grad() + >>> output = model(input) + >>> loss = compute_loss(output) + >>> model.backward(loss) + >>> model.finish_grad_sync() + >>> model.step() + >>> + >>> # Cleanup + >>> model.finalize() """ - - LORA_PARAM_PATTERN = re.compile(r'\.lora_\w+\.[^.]+\.') - def __init__( self, model: nn.Module, config: 'TransformerConfig', ddp_config: Optional['DistributedDataParallelConfig'] = None, - default_dp_group: Optional[dist.ProcessGroup] = None, ): """ - Initialize multi-tenant model. - + Initialize. + Args: - model: Base Megatron model (can be already wrapped with PEFT). - config: Transformer configuration. - ddp_config: DDP configuration. If None, creates default. - default_dp_group: Default data parallel group for tenants. + model: Base model with LoRA structure. + config: Transformer config. + ddp_config: DDP config. 
""" + super().__init__() + if not MEGATRON_AVAILABLE: - raise ImportError("Megatron-Core is required") - + raise ImportError('Megatron-Core required') + self.config = config self.ddp_config = ddp_config or DistributedDataParallelConfig( overlap_grad_reduce=True, use_distributed_optimizer=False, ) - - # Setup multi-adapter + + # Patch LoRA layers for multi-tenant self._multi_adapter = MegatronMultiAdapter() self.model = self._multi_adapter(model) - - # Tenant management - self._tenants: Dict[str, TenantState] = {} - self._default_dp_group = default_dp_group or mpu.get_data_parallel_group( - with_context_parallel=True - ) - - # DDP wrapper (created lazily after first tenant is added) - self._ddp: Optional[MegatronDDP] = None - - # Add a dummy adapter to ensure PEFT model structure is ready - self._ensure_peft_model() - - def _ensure_peft_model(self): - """Ensure the model is a PEFT model.""" - if not PEFT_AVAILABLE: - logger.warning("PEFT not available, skipping PEFT model check") - return - - if not isinstance(self.model, PeftModel): - # Create minimal LoRA config for structure - dummy_config = LoraConfig( - r=1, - target_modules='all-linear', - init_lora_weights=False, - ) - # Note: For Megatron models, you typically use Swift's prepare_model - logger.warning( - "Model is not a PeftModel. For Megatron LoRA, " - "use Swift.prepare_model() before wrapping." 
- ) - - def _wrap_with_ddp(self): - """Wrap model with Megatron DDP (lazy initialization).""" - if self._ddp is not None: - return - - self._ddp = MegatronDDP( + + # Create DDP + self._ddp = MultiTenantLoRADDP( config=self.config, ddp_config=self.ddp_config, module=self.model, ) - logger.info( - f"Created Megatron DDP with {len(self._ddp.params_with_grad)} params, " - f"{len(self._ddp.bucket_groups)} bucket groups" - ) - - def add_tenant( - self, - tenant_id: str, - lora_config: Optional['LoraConfig'] = None, - process_group: Optional[dist.ProcessGroup] = None, - optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.AdamW, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler_cls: Optional[Type] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Add a tenant with their LoRA configuration. - - Args: - tenant_id: Unique tenant identifier. - lora_config: LoRA configuration. If None, assumes adapter already exists. - process_group: Custom process group for this tenant's gradient sync. - optimizer_cls: Optimizer class. - optimizer_kwargs: Optimizer arguments. - scheduler_cls: LR scheduler class. - scheduler_kwargs: Scheduler arguments. - """ - if tenant_id in self._tenants: - logger.warning(f"Tenant '{tenant_id}' already exists, skipping") - return - - adapter_name = tenant_id - - # Add adapter if config provided and using PEFT - if lora_config is not None and PEFT_AVAILABLE and isinstance(self.model, PeftModel): - # Safety checks - lora_config.modules_to_save = None - lora_config.bias = 'none' - - self.model.add_adapter(adapter_name, lora_config) - logger.info(f"Added LoRA adapter '{adapter_name}'") - - # Set adapter as active to find its params - MegatronMultiAdapter.set_current_adapter_name(adapter_name) - - # Find trainable params for this adapter - trainable_params = [] - param_names = {} - - for name, param in self.model.named_parameters(): - if self.LORA_PARAM_PATTERN.search(name) and f'.{adapter_name}.' 
in name: - param.requires_grad = True - trainable_params.append(param) - param_names[param] = name - - # Create tenant state - state = TenantState( - adapter_name=adapter_name, - process_group=process_group or self._default_dp_group, - lora_config=lora_config, - trainable_params=trainable_params, - param_names=param_names, + + # Create tenant manager + self._manager = TenantManager( + model=self.model, + default_process_group=mpu.get_data_parallel_group( + with_context_parallel=True), ) - - # Create optimizer - if optimizer_kwargs is None: - optimizer_kwargs = {'lr': 1e-4, 'weight_decay': 0.01} - - state.optimizer = optimizer_cls(trainable_params, **optimizer_kwargs) - - # Create scheduler if specified - if scheduler_cls is not None: - scheduler_kwargs = scheduler_kwargs or {} - state.scheduler = scheduler_cls(state.optimizer, **scheduler_kwargs) - - self._tenants[tenant_id] = state - - logger.info( - f"Registered tenant '{tenant_id}' with {len(trainable_params)} " - f"trainable params ({sum(p.numel() for p in trainable_params):,} elements)" + + # Wire up callbacks + self._manager.register_add_callback(self._on_tenant_added) + self._manager.register_remove_callback(self._on_tenant_removed) + + logger.info('MultiTenantMegatronModel initialized') + + def _on_tenant_added(self, state: TenantState): + """Called when tenant is added via manager.""" + self._ddp.add_tenant( + tenant_id=state.tenant_id, + params=state.params, + process_group=state.process_group, + param_names=state.param_names, ) - - # Reset adapter context - MegatronMultiAdapter.set_current_adapter_name(None) - - def remove_tenant(self, tenant_id: str): - """Remove a tenant.""" - if tenant_id not in self._tenants: - logger.warning(f"Tenant '{tenant_id}' not found") - return - - state = self._tenants.pop(tenant_id) - - # Remove adapter from model if using PEFT - if PEFT_AVAILABLE and isinstance(self.model, PeftModel): - try: - self.model.delete_adapter(state.adapter_name) - except Exception as e: - 
logger.warning(f"Failed to delete adapter: {e}") - - logger.info(f"Removed tenant '{tenant_id}'") - - @contextmanager - def tenant_context(self, tenant_id: str): - """ - Context manager for tenant-specific operations. - - All forward/backward operations within this context will use - the specified tenant's LoRA adapter. - """ - if tenant_id not in self._tenants: - raise ValueError(f"Tenant '{tenant_id}' not registered") - - state = self._tenants[tenant_id] - prev_adapter = MegatronMultiAdapter.get_current_adapter_name() - - try: - MegatronMultiAdapter.set_current_adapter_name(state.adapter_name) - yield state - finally: - MegatronMultiAdapter.set_current_adapter_name(prev_adapter) - - def forward(self, *args, tenant_id: Optional[str] = None, **kwargs): - """ - Forward pass with tenant selection. - - Args: - *args: Model inputs. - tenant_id: Tenant to use. If None, uses current context. - **kwargs: Additional arguments. - """ - if tenant_id is not None: - MegatronMultiAdapter.set_current_adapter_name(tenant_id) - - # Ensure DDP is initialized - if self._ddp is None: - self._wrap_with_ddp() - + + def _on_tenant_removed(self, state: TenantState): + """Called when tenant is removed via manager.""" + if self._ddp.has_tenant(state.tenant_id): + self._ddp.remove_tenant(state.tenant_id) + + def forward(self, *args, **kwargs): + """Forward pass.""" return self._ddp(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def backward(self, loss: torch.Tensor, tenant_id: Optional[str] = None): - """ - Backward pass with optional tenant selection. - - Args: - loss: Loss tensor. - tenant_id: Tenant for gradient accumulation. - """ - if tenant_id is not None: - MegatronMultiAdapter.set_current_adapter_name(tenant_id) - - loss.backward() - - # Sync gradients for this tenant - self._reduce_tenant_gradients(tenant_id) - - def _reduce_tenant_gradients(self, tenant_id: Optional[str] = None): - """ - Reduce gradients for a specific tenant. 
- - For now, uses Megatron DDP's finish_grad_sync which syncs all params. - A more optimized version could filter to only tenant's params. - """ - tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() - - if self._ddp is not None: - self._ddp.finish_grad_sync() - - def step(self, tenant_id: Optional[str] = None): + + # ========== Tenant Lifecycle ========== + + def initialize(self, **kwargs) -> str: """ - Optimizer step for a tenant. - + Initialize a tenant. + Args: - tenant_id: Tenant to update. If None, uses current context. + **kwargs: Passed to TenantManager.initialize() + + Returns: + Tenant ID. """ - tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() - - if tenant_id is None: - raise ValueError("No tenant specified and no current tenant context") - - state = self._tenants.get(tenant_id) - if state is None: - raise ValueError(f"Tenant '{tenant_id}' not registered") - - if state.optimizer is not None: - state.optimizer.step() - + return self._manager.initialize(**kwargs) + + def finalize(self, tenant_id: Optional[str] = None): + """Finalize a tenant.""" + self._manager.finalize(tenant_id) + + @contextmanager + def scope(self, tenant_id: Optional[str] = None): + """Context manager for tenant scope.""" + with self._manager.scope(tenant_id) as state: + # Also set adapter + MegatronMultiAdapter.set_current_adapter_name(state.adapter_name) + try: + yield state + finally: + MegatronMultiAdapter.set_current_adapter_name(None) + + # ========== Training Operations ========== + def zero_grad(self, tenant_id: Optional[str] = None): - """Zero gradients for a tenant.""" - tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() - - if tenant_id is None: - # Zero all - if self._ddp is not None: - self._ddp.zero_grad_buffer() - return - - state = self._tenants.get(tenant_id) - if state is not None and state.optimizer is not None: - state.optimizer.zero_grad() - - def lr_step(self, tenant_id: Optional[str] = None): - 
"""LR scheduler step for a tenant.""" - tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() - - if tenant_id is None: - return - - state = self._tenants.get(tenant_id) - if state is not None and state.scheduler is not None: - state.scheduler.step() - + """Zero gradients.""" + tenant_id = tenant_id or require_tenant() + state = self._manager.get(tenant_id) + + self._ddp.zero_grad_buffer(tenant_id) + if state.optimizer: + state.optimizer.zero_grad(set_to_none=True) + + def backward(self, loss: torch.Tensor, tenant_id: Optional[str] = None): + """Backward pass.""" + tenant_id = tenant_id or require_tenant() + state = self._manager.get(tenant_id) + + MegatronMultiAdapter.set_current_adapter_name(state.adapter_name) + scaled_loss = loss / state.gradient_accumulation_steps + scaled_loss.backward() + + @contextmanager + def no_sync(self, tenant_id: Optional[str] = None): + """Disable gradient sync.""" + with self._ddp.no_sync(tenant_id): + yield + + def finish_grad_sync(self, tenant_id: Optional[str] = None): + """Finish gradient sync.""" + self._ddp.finish_grad_sync(tenant_id) + def clip_grad_norm( self, - max_norm: float = 1.0, - norm_type: float = 2.0, + max_norm: Optional[float] = None, tenant_id: Optional[str] = None, ) -> torch.Tensor: - """Clip gradients for a tenant.""" - tenant_id = tenant_id or MegatronMultiAdapter.get_current_adapter_name() - - if tenant_id is None: - raise ValueError("No tenant specified") - - state = self._tenants.get(tenant_id) - if state is None: - raise ValueError(f"Tenant '{tenant_id}' not registered") - - return torch.nn.utils.clip_grad_norm_( - state.trainable_params, max_norm, norm_type - ) - - def get_tenant_state(self, tenant_id: str) -> Optional[TenantState]: - """Get state for a tenant.""" - return self._tenants.get(tenant_id) - - def list_tenants(self) -> List[str]: - """List all registered tenants.""" - return list(self._tenants.keys()) - + """Clip gradients.""" + tenant_id = tenant_id or require_tenant() + 
state = self._manager.get(tenant_id) + max_norm = max_norm or state.max_grad_norm + return torch.nn.utils.clip_grad_norm_(state.params, max_norm) + + def step(self, tenant_id: Optional[str] = None): + """Optimizer step.""" + tenant_id = tenant_id or require_tenant() + state = self._manager.get(tenant_id) + if state.optimizer: + state.optimizer.step() + + def lr_step(self, tenant_id: Optional[str] = None): + """LR scheduler step.""" + tenant_id = tenant_id or require_tenant() + state = self._manager.get(tenant_id) + if state.scheduler: + state.scheduler.step() + + def get_lr(self, tenant_id: Optional[str] = None) -> Optional[float]: + """Get current LR.""" + tenant_id = tenant_id or require_tenant() + state = self._manager.get(tenant_id) + if state.optimizer: + return state.optimizer.param_groups[0]['lr'] + return None + + # ========== Utilities ========== + + def tenant_count(self) -> int: + """Get number of active tenants.""" + return self._manager.count() + + def has_tenant(self, tenant_id: str) -> bool: + """Check if a specific tenant exists.""" + return self._manager.has(tenant_id) + @property - def ddp(self) -> Optional[MegatronDDP]: - """Get the DDP wrapper.""" + def ddp(self) -> MultiTenantLoRADDP: + """Get DDP wrapper.""" return self._ddp - + + @property + def manager(self) -> TenantManager: + """Get tenant manager.""" + return self._manager + @property def unwrapped_model(self) -> nn.Module: - """Get the unwrapped model.""" + """Get unwrapped model.""" return self.model diff --git a/src/twinkle/megatron/model/qwen3.py b/src/twinkle/megatron/model/qwen3.py index b87b79ae..0acd3f67 100644 --- a/src/twinkle/megatron/model/qwen3.py +++ b/src/twinkle/megatron/model/qwen3.py @@ -11,18 +11,20 @@ # ============================================================================= class Qwen3ModelMeta: """Metadata for Qwen3 models.""" - + # Supported architectures - DENSE_ARCHITECTURES = ['Qwen3ForCausalLM', 'Qwen2ForCausalLM', 'Qwen2.5ForCausalLM'] + 
DENSE_ARCHITECTURES = [ + 'Qwen3ForCausalLM', 'Qwen2ForCausalLM', 'Qwen2.5ForCausalLM' + ] MOE_ARCHITECTURES = ['Qwen3MoeForCausalLM', 'Qwen2MoeForCausalLM'] ALL_ARCHITECTURES = DENSE_ARCHITECTURES + MOE_ARCHITECTURES - + # HuggingFace key prefixes HF_LAYERS_PREFIX = 'model.layers' HF_EMBED_KEY = 'model.embed_tokens.weight' HF_FINAL_LAYERNORM_KEY = 'model.norm.weight' HF_LM_HEAD_KEY = 'lm_head.weight' - + # Qwen3 specific settings DEFAULT_CONFIG = { 'qk_layernorm': True, @@ -30,17 +32,17 @@ class Qwen3ModelMeta: 'disable_bias_linear': True, 'rotary_interleaved': False, } - + # MoE specific settings MOE_CONFIG = { 'use_shared_expert_gate': True, } - + @classmethod def is_qwen3(cls, architecture: str) -> bool: """Check if architecture is a Qwen3 model.""" return architecture in cls.ALL_ARCHITECTURES - + @classmethod def is_qwen3_moe(cls, architecture: str) -> bool: """Check if architecture is a Qwen3 MoE model.""" @@ -49,10 +51,10 @@ def is_qwen3_moe(cls, architecture: str) -> bool: def get_model_default_config(architecture: str) -> Dict[str, Any]: """Get default config overrides for a model architecture. - + Args: architecture: Model architecture name. - + Returns: Default config dict for Megatron TransformerConfig. 
""" diff --git a/src/twinkle/megatron/tuners/lora.py b/src/twinkle/megatron/tuners/lora.py index dbaefe7a..22075f7f 100644 --- a/src/twinkle/megatron/tuners/lora.py +++ b/src/twinkle/megatron/tuners/lora.py @@ -8,46 +8,42 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version +from peft.tuners.lora import model +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose # Direct imports - assume megatron and peft are installed import megatron.core from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import ( - TEColumnParallelGroupedLinear, TEColumnParallelLinear, - TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear, - TERowParallelGroupedLinear, TERowParallelLinear -) -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding + TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEGroupedLinear, + TELayerNormColumnParallelLinear, TELinear, TERowParallelGroupedLinear, + TERowParallelLinear) +from megatron.core.models.common.embeddings.language_model_embedding import \ + LanguageModelEmbedding from megatron.core.parallel_state import ( - get_expert_tensor_parallel_world_size, - get_tensor_model_parallel_world_size -) + get_expert_tensor_parallel_world_size, + get_tensor_model_parallel_world_size) from megatron.core.tensor_parallel import ( - gather_from_sequence_parallel_region, - scatter_to_sequence_parallel_region -) + gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region) from megatron.core.transformer.mlp import apply_swiglu_sharded_factory from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.router import TopKRouter -from packaging import version -from peft.tuners.lora import model -from 
peft.tuners.lora.layer import LoraLayer -from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge -from peft.utils.other import transpose - -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') class LoraParallelLinear(MegatronModule, LoraLayer): """LoRA layer compatible with Megatron Tensor Parallel Linear layers. - - This class wraps Megatron's parallel linear layers (TELinear, TEColumnParallelLinear, + + This class wraps Megatron's parallel linear layers (TELinear, TEColumnParallelLinear, TERowParallelLinear, etc.) and adds LoRA adapters that are correctly sharded across tensor parallel ranks. """ - def __init__( self, base_layer, @@ -63,7 +59,7 @@ def __init__( **kwargs, ): """Initialize LoraParallelLinear. - + Args: base_layer: The Megatron parallel linear layer to wrap. adapter_name: Name of the LoRA adapter. @@ -83,20 +79,24 @@ def __init__( LoraLayer.__init__(self, base_layer=base_layer) if use_dora: - raise ValueError(f'{self.__class__.__name__} does not support DoRA yet, please set it to False') + raise ValueError( + f'{self.__class__.__name__} does not support DoRA yet, please set it to False' + ) - self.is_parallel_a = isinstance(base_layer, (TERowParallelLinear, TERowParallelGroupedLinear)) + self.is_parallel_a = isinstance( + base_layer, (TERowParallelLinear, TERowParallelGroupedLinear)) self.is_grouped = isinstance(base_layer, TEGroupedLinear) self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name self.is_expert = getattr(base_layer, 'is_expert', False) - self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False) - + self.sequence_parallel = getattr(base_layer, 'sequence_parallel', + False) + if self.is_expert: self.tp_size = get_expert_tensor_parallel_world_size() else: self.tp_size = get_tensor_model_parallel_world_size() - + self.update_layer( adapter_name, r, @@ -109,20 +109,11 @@ 
def __init__( self.is_target_conv_1d_layer = False - def update_layer( - self, - adapter_name: str, - r: int, - *, - lora_alpha: int, - lora_dropout: float, - init_lora_weights: bool, - use_rslora: bool, - lora_bias: bool, - **kwargs - ): + def update_layer(self, adapter_name: str, r: int, *, lora_alpha: int, + lora_dropout: float, init_lora_weights: bool, + use_rslora: bool, lora_bias: bool, **kwargs): """Update LoRA layer with new adapter configuration. - + Args: adapter_name: Name of the adapter. r: LoRA rank. @@ -133,11 +124,13 @@ def update_layer( lora_bias: Whether to add bias. """ if r <= 0: - raise ValueError(f'`r` should be a positive integer value but the value passed is {r}') - + raise ValueError( + f'`r` should be a positive integer value but the value passed is {r}' + ) + self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha - + if lora_dropout > 0.0: lora_dropout_layer = nn.Dropout(p=lora_dropout) else: @@ -154,7 +147,7 @@ def update_layer( } if mcore_013: kwargs['tp_group'] = self.base_layer.tp_group - + if isinstance(self.base_layer, TopKRouter): # Router layer - no parallelism needed router_shape = self.base_layer.weight.shape @@ -214,14 +207,12 @@ def update_layer( # Column parallel layer - LoRA A is not parallel, LoRA B is parallel out_features = self.out_features * self.tp_size if self.is_grouped: - lora_a = TEGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=self.in_features, - output_size=r, - bias=lora_bias, - parallel_mode=None, - **kwargs - ) + lora_a = TEGroupedLinear(num_gemms=self.base_layer.num_gemms, + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + **kwargs) lora_b = TEColumnParallelGroupedLinear( num_gemms=self.base_layer.num_gemms, input_size=r, @@ -230,14 +221,12 @@ def update_layer( **kwargs, ) else: - lora_a = TELinear( - input_size=self.in_features, - output_size=r, - bias=lora_bias, - parallel_mode=None, - skip_weight_param_allocation=False, - **kwargs - ) + 
lora_a = TELinear(input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs) lora_b = TEColumnParallelLinear( input_size=r, output_size=out_features, @@ -249,35 +238,39 @@ def update_layer( # Disable overlap for LoRA layers for lora in [lora_a, lora_b]: - if isinstance(lora, (TERowParallelLinear, TEColumnParallelLinear)) and lora.parallel_mode is None: + if isinstance( + lora, + (TERowParallelLinear, + TEColumnParallelLinear)) and lora.parallel_mode is None: lora.ub_overlap_rs_fprop = False lora.ub_overlap_ag_dgrad = False lora.ub_overlap_ag_fprop = False lora.ub_overlap_rs_dgrad = False - + lora_a.sequence_parallel = False lora_b.sequence_parallel = False - + self.lora_A[adapter_name] = lora_a self.lora_B[adapter_name] = lora_b - + if hasattr(self, 'lora_bias'): self.lora_bias[adapter_name] = lora_bias - + if use_rslora: - self.scaling[adapter_name] = lora_alpha / (r ** 0.5) + self.scaling[adapter_name] = lora_alpha / (r**0.5) else: self.scaling[adapter_name] = lora_alpha / r - + if init_lora_weights: self.reset_lora_parameters(adapter_name, init_lora_weights) self._move_adapter_to_device_of_base_layer(adapter_name) self.set_adapter(self.active_adapters) - def reset_lora_parameters(self, adapter_name: str, init_lora_weights: bool): + def reset_lora_parameters(self, adapter_name: str, + init_lora_weights: bool): """Reset LoRA parameters to initial values. - + Args: adapter_name: Name of the adapter. init_lora_weights: Initialization method. 
@@ -288,28 +281,35 @@ def reset_lora_parameters(self, adapter_name: str, init_lora_weights: bool): if adapter_name in self.lora_A.keys(): lora_a = self.lora_A[adapter_name] lora_b = self.lora_B[adapter_name] - + if isinstance(lora_a, TEGroupedLinear): - weights_a = [getattr(lora_a, f'weight{i}') for i in range(lora_a.num_gemms)] + weights_a = [ + getattr(lora_a, f'weight{i}') + for i in range(lora_a.num_gemms) + ] else: weights_a = [lora_a.weight] - + if isinstance(lora_b, TEGroupedLinear): - weights_b = [getattr(lora_b, f'weight{i}') for i in range(lora_b.num_gemms)] + weights_b = [ + getattr(lora_b, f'weight{i}') + for i in range(lora_b.num_gemms) + ] else: weights_b = [lora_b.weight] - + for weight_a in weights_a: if init_lora_weights is True: nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5)) elif init_lora_weights.lower() == 'gaussian': nn.init.normal_(weight_a, std=1 / self.r[adapter_name]) else: - raise ValueError(f'Unknown initialization {init_lora_weights=}') - + raise ValueError( + f'Unknown initialization {init_lora_weights=}') + for weight_b in weights_b: nn.init.zeros_(weight_b) - + if adapter_name in self.lora_embedding_A.keys(): nn.init.zeros_(self.lora_embedding_A[adapter_name]) nn.init.normal_(self.lora_embedding_B[adapter_name]) @@ -330,10 +330,12 @@ def gating(_self, x): scaling = self.scaling[active_adapter] x = x.to(result.dtype) - lora_result = F.linear(dropout(x), lora_A.weight.to(result.dtype)) + lora_result = F.linear(dropout(x), + lora_A.weight.to(result.dtype)) if isinstance(lora_result, tuple): lora_result = lora_result[0] - lora_result = F.linear(lora_result, lora_B.weight.to(result.dtype)) + lora_result = F.linear(lora_result, + lora_B.weight.to(result.dtype)) if isinstance(lora_result, tuple): lora_result = lora_result[0] lora_result = lora_result * scaling @@ -349,12 +351,12 @@ def gating(_self, x): def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): """Forward pass with LoRA adaptation. - + Args: x: Input tensor. 
*args: Additional positional arguments. **kwargs: Additional keyword arguments. - + Returns: Tuple of (output tensor, bias). """ @@ -375,36 +377,45 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): with self._patch_router_gating(): result, bias = self.base_layer(x, *args, **kwargs) else: - raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}') - - if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged: + raise ValueError( + f'Unsupported base layer type: {type(self.base_layer)}') + + if not isinstance( + self.base_layer, + TopKRouter) and not self.disable_adapters and not self.merged: if self.sequence_parallel and self.base_layer.parallel_mode == 'column': x = gather_from_sequence_parallel_region(x) - + for active_adapter in self.active_adapters: if active_adapter not in self.lora_A.keys(): continue - + lora_A = self.lora_A[active_adapter] lora_B = self.lora_B[active_adapter] dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] - dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype + dtype = lora_A.weight0.dtype if isinstance( + lora_A, TEGroupedLinear) else lora_A.weight.dtype x = x.to(dtype) - lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, TEGroupedLinear) else lora_A(dropout(x)) + lora_result = lora_A( + dropout(x), *args, **kwargs) if isinstance( + lora_A, TEGroupedLinear) else lora_A(dropout(x)) if isinstance(lora_result, tuple): lora_result = lora_result[0] - - lora_result = lora_B(lora_result, *args, **kwargs) if isinstance(lora_B, TEGroupedLinear) else lora_B(lora_result) + + lora_result = lora_B( + lora_result, *args, **kwargs) if isinstance( + lora_B, TEGroupedLinear) else lora_B(lora_result) if isinstance(lora_result, tuple): lora_result = lora_result[0] - + lora_result = lora_result * scaling - + if self.sequence_parallel and self.base_layer.parallel_mode == 'row': - lora_result = 
scatter_to_sequence_parallel_region(lora_result) - + lora_result = scatter_to_sequence_parallel_region( + lora_result) + result = result + lora_result result = result.to(previous_dtype) @@ -417,44 +428,51 @@ def sharded_state_dict( metadata: Optional[dict] = None, ) -> ShardedStateDict: """Get sharded state dict for distributed checkpointing. - + Args: prefix: Key prefix. sharded_offsets: Sharding offsets. metadata: Additional metadata. - + Returns: Sharded state dictionary. """ from ..utils import tuners_sharded_state_dict - - sharded_state_dict = tuners_sharded_state_dict(self, prefix, sharded_offsets, metadata) - + + sharded_state_dict = tuners_sharded_state_dict(self, prefix, + sharded_offsets, + metadata) + if prefix.endswith('linear_fc1.'): - if isinstance(self.base_layer, TEGroupedLinear) and self.config.gated_linear_unit: + if isinstance(self.base_layer, + TEGroupedLinear) and self.config.gated_linear_unit: num_global_experts = ( - parallel_state.get_expert_model_parallel_world_size() * self.base_layer.num_gemms - ) + parallel_state.get_expert_model_parallel_world_size() * + self.base_layer.num_gemms) local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.base_layer.num_gemms - ) + parallel_state.get_expert_model_parallel_rank() * + self.base_layer.num_gemms) ep_axis = len(sharded_offsets) for i in range(self.base_layer.num_gemms): new_sharded_offsets = ( *sharded_offsets, - (ep_axis, local_expert_indices_offset + i, num_global_experts), + (ep_axis, local_expert_indices_offset + i, + num_global_experts), ) - for k in (f'{prefix}base_layer.weight{i}', f'{prefix}base_layer.bias{i}'): + for k in (f'{prefix}base_layer.weight{i}', + f'{prefix}base_layer.bias{i}'): if k in sharded_state_dict: - sharded_state_dict[k] = apply_swiglu_sharded_factory( - sharded_state_dict[k], new_sharded_offsets - ) + sharded_state_dict[ + k] = apply_swiglu_sharded_factory( + sharded_state_dict[k], new_sharded_offsets) else: for k, v in 
sharded_state_dict.items(): - if k in [f'{prefix}base_layer.weight', f'{prefix}base_layer.bias']: + if k in [ + f'{prefix}base_layer.weight', + f'{prefix}base_layer.bias' + ]: sharded_state_dict[k] = apply_swiglu_sharded_factory( - sharded_state_dict[k], sharded_offsets - ) + sharded_state_dict[k], sharded_offsets) return sharded_state_dict def get_delta_weights(self, adapter: str) -> List[torch.Tensor]: @@ -462,31 +480,37 @@ def get_delta_weights(self, adapter: str) -> List[torch.Tensor]: Args: adapter: The name of the adapter. - + Returns: List of delta weight tensors. """ lora_A = self.lora_A[adapter] lora_B = self.lora_B[adapter] - + if self.is_grouped: - weight_A = [getattr(lora_A, f'weight{i}') for i in range(lora_A.num_gemms)] - weight_B = [getattr(lora_B, f'weight{i}') for i in range(lora_B.num_gemms)] + weight_A = [ + getattr(lora_A, f'weight{i}') for i in range(lora_A.num_gemms) + ] + weight_B = [ + getattr(lora_B, f'weight{i}') for i in range(lora_B.num_gemms) + ] else: weight_A = [self.lora_A[adapter].weight] weight_B = [self.lora_B[adapter].weight] - + output_tensor = [] assert len(weight_A) == len(weight_B) - + for i in range(len(weight_B)): output_tensor.append( - transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * self.scaling[adapter] - ) + transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * + self.scaling[adapter]) return output_tensor - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge(self, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None) -> None: """Merge the active adapter weights into the base weights. 
Args: @@ -499,24 +523,33 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N base_layer = self.get_base_layer() origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device - + if origin_device.type == 'cpu': - device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + device = torch.cuda.current_device() if torch.cuda.is_available( + ) else 'cpu' self.to(device=device) - + for active_adapter in adapter_names: if active_adapter in self.lora_A.keys(): if self.is_grouped: - orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + orig_weights = [ + getattr(base_layer, f'weight{i}') + for i in range(base_layer.num_gemms) + ] else: orig_weights = [base_layer.weight] - + if safe_merge: - orig_weights = [weight.data.clone() for weight in orig_weights] + orig_weights = [ + weight.data.clone() for weight in orig_weights + ] delta_weights = self.get_delta_weights(active_adapter) - for orig_weight, delta_weight in zip(orig_weights, delta_weights): + for orig_weight, delta_weight in zip( + orig_weights, delta_weights): orig_weight += delta_weight - if not all(torch.isfinite(orig_weights[i]).all() for i in range(len(orig_weights))): + if not all( + torch.isfinite(orig_weights[i]).all() + for i in range(len(orig_weights))): raise ValueError( f'NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken' ) @@ -528,11 +561,12 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N base_layer.weight.data = orig_weights[0] else: delta_weights = self.get_delta_weights(active_adapter) - for orig_weight, delta_weight in zip(orig_weights, delta_weights): + for orig_weight, delta_weight in zip( + orig_weights, delta_weights): orig_weight.data += delta_weight - + self.merged_adapters.append(active_adapter) - + if origin_device.type == 'cpu': self.to(device=origin_device) @@ -543,20 +577,25 @@ def unmerge(self) -> None: base_layer = self.get_base_layer() origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device - + if origin_device.type == 'cpu': - device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + device = torch.cuda.current_device() if torch.cuda.is_available( + ) else 'cpu' self.to(device=device) for active_adapter in self.merged_adapters: if active_adapter in self.lora_A.keys(): if self.is_grouped: - orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + orig_weights = [ + getattr(base_layer, f'weight{i}') + for i in range(base_layer.num_gemms) + ] else: orig_weights = [base_layer.weight] delta_weights = self.get_delta_weights(active_adapter) - for orig_weight, delta_weight in zip(orig_weights, delta_weights): + for orig_weight, delta_weight in zip(orig_weights, + delta_weights): orig_weight.data -= delta_weight self.merged_adapters = [] @@ -572,13 +611,13 @@ def dispatch_megatron( **kwargs: Any, ) -> Optional[torch.nn.Module]: """Dispatch function to replace Megatron linear layers with LoRA layers. - + Args: target: The target module to potentially replace. adapter_name: Name of the LoRA adapter. lora_config: LoRA configuration. **kwargs: Additional arguments for LoraParallelLinear. - + Returns: LoraParallelLinear if target is a compatible layer, None otherwise. 
""" @@ -589,9 +628,12 @@ def dispatch_megatron( else: target_base_layer = target - linear_cls = (TELayerNormColumnParallelLinear, TELinear, TEGroupedLinear, TopKRouter) + linear_cls = (TELayerNormColumnParallelLinear, TELinear, TEGroupedLinear, + TopKRouter) if isinstance(target_base_layer, linear_cls): - new_module = LoraParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + new_module = LoraParallelLinear(base_layer=target, + adapter_name=adapter_name, + **kwargs) return new_module @@ -601,4 +643,3 @@ def dispatch_megatron( model.dispatch_megatron = dispatch_megatron except Exception: pass - diff --git a/src/twinkle/megatron/utils.py b/src/twinkle/megatron/utils.py index 68e5663e..7f465230 100644 --- a/src/twinkle/megatron/utils.py +++ b/src/twinkle/megatron/utils.py @@ -430,10 +430,6 @@ def get_tenant_manager() -> TenantProcessGroupManager: _tenant_manager = TenantProcessGroupManager() return _tenant_manager - -# ============================================================================= - -# ============================================================================= def find_layers(model: nn.Module, cond_fn) -> List[str]: """Find all layers in model matching condition function. 
diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index e4cf71d2..08f3af06 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -7,19 +7,20 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler import twinkle -from twinkle import remote_class, remote_function, template, DeviceMesh +from twinkle import DeviceMesh, remote_class, remote_function, template from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.loss import Loss, MegatronCrossEntropyLoss from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils.plugin import Plugin + from .base import TwinkleModel from .strategy import MegatronStrategy @@ -29,7 +30,8 @@ from megatron.core.distributed import DistributedDataParallel as MegatronDDP from packaging import version MEGATRON_AVAILABLE = True - mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') except ImportError: MEGATRON_AVAILABLE = False mcore_013 = False @@ -38,7 +40,7 @@ @dataclass class MegatronOptimizerGroup: """Optimizer group for Megatron training. - + Similar to OptimizerGroup but adapted for Megatron's distributed training. 
""" adapter_name: str = None @@ -59,7 +61,9 @@ class MegatronOptimizerGroup: _last_grad_norm: float = 0.0 _last_step_success: bool = True - def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: + def do_grad_sync(self, + gradient_accumulation_steps: Optional[int] = None + ) -> bool: """Check if gradient synchronization should happen.""" if gradient_accumulation_steps is None: gradient_accumulation_steps = self.gradient_accumulation_steps @@ -73,22 +77,21 @@ def check_megatron_available(): """Check if Megatron-Core is available.""" if not MEGATRON_AVAILABLE: raise ImportError( - "Megatron-Core is not installed. Please install it with: " - "pip install megatron-core" - ) + 'Megatron-Core is not installed. Please install it with: ' + 'pip install megatron-core') @remote_class(execute='all') class MegatronModel(TwinkleModel, nn.Module): """Megatron-Core model wrapper for twinkle training framework. - + Note: Uses execute='all' to create workers on all ranks, which is required for Megatron's TP/DP parallelism where all ranks must participate in collective operations like gradient all-reduce. - + This class provides a similar API to TransformersModel but uses Megatron-Core as the training backend, supporting TP/PP/CP/EP parallelism. - + Args: pretrained_model_name_or_path: HuggingFace model path or ID. device_mesh: Twinkle DeviceMesh for distributed training. @@ -101,7 +104,6 @@ class MegatronModel(TwinkleModel, nn.Module): use_distributed_optimizer: Use Megatron's distributed optimizer. **kwargs: Additional arguments passed to model initialization. 
""" - def __init__( self, pretrained_model_name_or_path: str, @@ -114,28 +116,30 @@ def __init__( mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', use_distributed_optimizer: bool = True, load_weights: bool = True, - use_megatron_bridge: bool = True, # Use bridge-based initialization (recommended) - recompute_granularity: Optional[str] = 'selective', # Activation checkpointing + use_megatron_bridge: + bool = True, # Use bridge-based initialization (recommended) + recompute_granularity: Optional[ + str] = 'selective', # Activation checkpointing recompute_modules: Optional[list] = None, # Modules to recompute **kwargs, ): check_megatron_available() nn.Module.__init__(self) - + self.model_id = pretrained_model_name_or_path self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.use_megatron_bridge = use_megatron_bridge self.recompute_granularity = recompute_granularity self.recompute_modules = recompute_modules - + # Load HuggingFace config first model_path = HubOperation.download_model(pretrained_model_name_or_path) self._load_hf_config(model_path) - + # Store model_path for later use self._model_path = model_path - + # Create Megatron strategy self.strategy = MegatronStrategy( tensor_model_parallel_size=tensor_model_parallel_size, @@ -146,25 +150,27 @@ def __init__( use_distributed_optimizer=use_distributed_optimizer, mixed_precision=mixed_precision, ) - + # Initialize parallel state (skip if using bridge init, as it handles this) if not use_megatron_bridge: self.strategy.initialize() - + # Create Megatron model - self.model = self._create_megatron_model(model_path, load_weights, **kwargs) - + self.model = self._create_megatron_model(model_path, load_weights, + **kwargs) + self._model_wrapped = False # This correctly handles vocab sharding in Tensor Parallelism self.optimizer_group: Dict[str, MegatronOptimizerGroup] = { - _default_adapter_name: MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) + _default_adapter_name: + 
MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) } - + def _load_hf_config(self, model_path: str): """Load HuggingFace model config.""" from transformers import AutoConfig self.hf_config = AutoConfig.from_pretrained(model_path) - + def _create_megatron_model( self, model_path: str, @@ -172,12 +178,12 @@ def _create_megatron_model( **kwargs, ) -> nn.Module: """Create Megatron model from HuggingFace checkpoint. - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. **kwargs: Additional arguments. - + Returns: Megatron model on GPU. """ @@ -186,15 +192,17 @@ def _create_megatron_model( params_dtype = torch.float16 elif self.mixed_precision == 'no': params_dtype = torch.float32 - + if self.use_megatron_bridge: # Use bridge-based initialization (recommended) # This ensures all patches are applied and config is correctly generated - return self._create_megatron_model_with_bridge(model_path, load_weights, params_dtype, **kwargs) + return self._create_megatron_model_with_bridge( + model_path, load_weights, params_dtype, **kwargs) else: # Use twinkle's native initialization - return self._create_megatron_model_native(model_path, load_weights, params_dtype, **kwargs) - + return self._create_megatron_model_native(model_path, load_weights, + params_dtype, **kwargs) + def _create_megatron_model_with_bridge( self, model_path: str, @@ -203,25 +211,25 @@ def _create_megatron_model_with_bridge( **kwargs, ) -> nn.Module: """Create Megatron model using bridge-based initialization flow. - + This approach uses TwinkleBridgeInitializer for independent initialization It includes: - Proper config conversion from HuggingFace to Megatron format - Correct Megatron initialization (initialize_megatron) - Correct model creation - Weight loading with TwinkleGPTBridge - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. params_dtype: Parameter dtype. **kwargs: Additional arguments. 
- + Returns: Megatron model on GPU. """ from twinkle.megatron.model.bridge import TwinkleBridgeInitializer - + # Create bridge-based initializer self._bridge_initializer = TwinkleBridgeInitializer( tp_size=self.strategy.tp_size, @@ -234,25 +242,27 @@ def _create_megatron_model_with_bridge( sequence_parallel=self.strategy.sequence_parallel, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules, - recompute_method=getattr(self, "recompute_method", None), - recompute_num_layers=getattr(self, "recompute_num_layers", None), + recompute_method=getattr(self, 'recompute_method', None), + recompute_num_layers=getattr(self, 'recompute_num_layers', None), ) - + # Create model (this calls initialize_megatron internally) - model = self._bridge_initializer.create_model(model_path, load_weights=load_weights) - + model = self._bridge_initializer.create_model( + model_path, load_weights=load_weights) + # Update strategy state since bridge has initialized Megatron self.strategy._initialized = True self.strategy._parallel_state = mpu - + # Save transformer config for DDP wrapping - self._transformer_config = getattr(self._bridge_initializer, '_transformer_config', None) - + self._transformer_config = getattr(self._bridge_initializer, + '_transformer_config', None) + # Move to GPU model = self._move_model_to_gpu(model) - + return model - + def _create_megatron_model_native( self, model_path: str, @@ -261,20 +271,20 @@ def _create_megatron_model_native( **kwargs, ) -> nn.Module: """Create Megatron model using twinkle's native initialization. - + This is the fallback method when bridge is not available. - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. params_dtype: Parameter dtype. **kwargs: Additional arguments. - + Returns: Megatron model on GPU. 
""" from twinkle.megatron.model.initializer import MegatronModelInitializer - + initializer = MegatronModelInitializer( tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, @@ -283,43 +293,44 @@ def _create_megatron_model_native( sequence_parallel=self.strategy.sequence_parallel, params_dtype=params_dtype, ) - + # Create model model = initializer.create_gpt_model(self.hf_config, **kwargs) - + # Load weights if load_weights: initializer.load_from_hf(model, model_path, self.hf_config) - + model = self._move_model_to_gpu(model) - + return model - + def _move_model_to_gpu(self, model: nn.Module) -> nn.Module: """Move model to correct GPU device. - + This method handles moving parameters, buffers, and any cached tensors (like RoPE embeddings) to the correct device for distributed training. """ # Determine the target device based on local rank - local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0 + local_rank = dist.get_rank() % torch.cuda.device_count( + ) if dist.is_initialized() else 0 device = torch.device(f'cuda:{local_rank}') - + # Set CUDA device explicitly torch.cuda.set_device(local_rank) - + # Move all parameters and buffers to GPU model = model.to(device) - + # Force synchronize to ensure all transfers complete if torch.cuda.is_available(): torch.cuda.synchronize(device) - + return model - + def _lazy_wrap_model(self): """Lazily wrap model with distributed wrapper. - + Note: This should only be called after prepare_training() has been executed on all workers. Direct calls from forward() may cause deadlocks if not all DP ranks are participating. 
@@ -328,9 +339,10 @@ def _lazy_wrap_model(self): # Find an optimizer from any adapter group (prefer default, then first available) optimizer = None optimizer_adapter = None - + if _default_adapter_name in self.optimizer_group: - optimizer = self.optimizer_group[_default_adapter_name].optimizer + optimizer = self.optimizer_group[ + _default_adapter_name].optimizer optimizer_adapter = _default_adapter_name else: for name, group in self.optimizer_group.items(): @@ -338,16 +350,17 @@ def _lazy_wrap_model(self): optimizer = group.optimizer optimizer_adapter = name break - + if optimizer is not None: - self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) + self.model, optimizer = self.strategy.wrap_model( + self.model, optimizer) self.optimizer_group[optimizer_adapter].optimizer = optimizer self._model_wrapped = True - + @remote_function(dispatch='all') def prepare_training(self, **kwargs): """Prepare model for training. - + Note: In Ray-based Megatron training, we skip DDP wrapping to avoid deadlocks from collective operations. Each DP replica trains independently. This method still calls _lazy_wrap_model for any non-DDP setup needed. @@ -355,20 +368,22 @@ def prepare_training(self, **kwargs): self._lazy_wrap_model() @remote_function() - def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): + def forward(self, *, inputs: Union[InputFeature, List[InputFeature], + Trajectory, List[Trajectory]], + **kwargs): """Forward pass with Megatron model. - + Args: inputs: Model inputs. **kwargs: Additional arguments including adapter_name. - + Returns: Model outputs. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() - + # Encode inputs if needed if isinstance(inputs, dict) and 'input_ids' not in inputs: if optimizer_config.template is not None: @@ -376,33 +391,33 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, if isinstance(inputs, list) and 'input_ids' not in inputs[0]: if optimizer_config.template is not None: inputs = optimizer_config.template.batch_encode(inputs) - + # Process inputs processor: InputProcessor = optimizer_config.processor if processor is not None: inputs: Dict[str, Any] = processor(inputs) - + labels = inputs.get('labels', None) if 'labels' in inputs: try: del inputs['labels'] except (TypeError, KeyError): pass # Some dict-like types don't support deletion - + # Forward through model outputs = self._forward_step(inputs) - + inputs['labels'] = labels optimizer_config.inputs = inputs optimizer_config.outputs = outputs return outputs - + def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Execute forward step with pipeline parallelism support. - + Args: inputs: Processed inputs. - + Returns: Model outputs. 
""" @@ -411,16 +426,16 @@ def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return self._forward_step_pipeline(inputs) else: return self._forward_step_simple(inputs) - + def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Simple forward step without pipeline parallelism.""" model = self.strategy.unwrap_model(self.model) - + # Prepare inputs for Megatron input_ids = inputs.get('input_ids') attention_mask = inputs.get('attention_mask') position_ids = inputs.get('position_ids') - + # Create position_ids if not provided if position_ids is None and input_ids is not None: position_ids = torch.arange( @@ -428,46 +443,47 @@ def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: device=input_ids.device, dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - + # Forward pass outputs = model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, ) - + return {'logits': outputs} - + def _forward_step_pipeline(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Forward step with pipeline parallelism. - + Note: For PP > 1, the forward pass is handled by Megatron's pipeline scheduler in forward_backward(). This method is for simple forward-only inference. For training, use forward_backward() which uses get_forward_backward_func(). """ from twinkle.megatron.utils import forward_step_helper - + model = self.strategy.unwrap_model(self.model) - + # Use pipeline forward helper output = forward_step_helper( model, inputs, model.config, ) - + if output is not None: return {'logits': output} return {} @remote_function() - def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): + def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], + List[Trajectory]], **kwargs): """Forward pass without gradient computation. - + Args: inputs: Model inputs. **kwargs: Additional arguments. - + Returns: Model outputs. 
""" @@ -477,23 +493,23 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T @remote_function(collect='avg') def calculate_loss(self, **kwargs): """Calculate loss from forward outputs. - + Args: **kwargs: Additional arguments including adapter_name. - + Returns: Loss value as numpy array. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] loss_instance: Loss = optimizer_config.loss_instance - + inputs = optimizer_config.inputs outputs = optimizer_config.outputs - + assert inputs is not None and outputs is not None, \ 'Cannot calculate loss of empty inputs and outputs' - + loss_value = loss_instance(inputs, outputs, **kwargs) optimizer_config.loss_value = loss_value return loss_value.detach().cpu().float().numpy() @@ -501,38 +517,42 @@ def calculate_loss(self, **kwargs): @remote_function() def backward(self, **kwargs): """Backward pass. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] loss_value = optimizer_config.loss_value - + assert loss_value is not None, 'Do forwarding and calculating loss before backward' - + _gas = optimizer_config.gradient_accumulation_steps if 'gradient_accumulation_steps' in kwargs: _gas = kwargs['gradient_accumulation_steps'] - + loss_value = loss_value / _gas loss_value.backward() optimizer_config.cur_step += 1 @remote_function(dispatch='all', collect='avg', sync=True) - def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - num_microbatches: int = 1, **kwargs): + def forward_backward(self, + *, + inputs: Union[InputFeature, List[InputFeature], + Trajectory, List[Trajectory]], + num_microbatches: int = 1, + **kwargs): """Combined forward and backward pass using Megatron's scheduler. 
- + Note: sync=True is required for Ray mode because Megatron's pipeline parallel uses NCCL P2P communication that requires all ranks to enter the function simultaneously. - + Always uses Megatron's get_forward_backward_func() which handles: - Pipeline scheduling (1F1B, interleaved, or no-pipeline) - Communication between stages (using proper process groups for multi-tenant isolation) - Gradient accumulation across microbatches - + Args: inputs: Model inputs. Can be: - A single batch dict (num_microbatches=1) @@ -543,33 +563,34 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr Using num_microbatches > 1 enables Megatron's native gradient accumulation with better memory management and compute overlap. **kwargs: Additional arguments. - + Returns: Average loss value across all microbatches. """ from functools import partial from megatron.core.pipeline_parallel import get_forward_backward_func - + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() - + # Handle different input formats # 1. Single batch dict -> wrap in list # 2. List of batches -> use as-is # 3. 
Iterator -> convert to list if isinstance(inputs, dict): microbatch_list = [inputs] - elif hasattr(inputs, '__iter__') and not isinstance(inputs, (list, tuple)): + elif hasattr(inputs, + '__iter__') and not isinstance(inputs, (list, tuple)): # Iterator - convert to list microbatch_list = list(inputs) else: microbatch_list = list(inputs) - + # Infer num_microbatches from inputs if list is provided if len(microbatch_list) > 1: num_microbatches = len(microbatch_list) - + # Process each microbatch processed_batches = [] for batch in microbatch_list: @@ -577,85 +598,89 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr if isinstance(batch, dict) and 'input_ids' not in batch: if optimizer_config.template is not None: batch = optimizer_config.template.encode(batch) - + # Process inputs processor = optimizer_config.processor if processor is not None: batch = processor(batch) - + processed_batches.append(batch) - + # Get first batch for shape info (all batches should have same shape) first_batch = processed_batches[0] - + # Get CP size for sequence padding and splitting cp_size = self.strategy.cp_size cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0 - + # Get sequence length and batch size from first batch - original_seq_length = first_batch['input_ids'].shape[1] if 'input_ids' in first_batch else 1 - micro_batch_size = first_batch['input_ids'].shape[0] if 'input_ids' in first_batch else 1 - + original_seq_length = first_batch['input_ids'].shape[ + 1] if 'input_ids' in first_batch else 1 + micro_batch_size = first_batch['input_ids'].shape[ + 0] if 'input_ids' in first_batch else 1 + # For CP > 1, pad seq_length to be divisible by 2*cp_size if cp_size > 1: divisor = 2 * cp_size if original_seq_length % divisor != 0: - seq_length = original_seq_length + (divisor - original_seq_length % divisor) + seq_length = original_seq_length + ( + divisor - original_seq_length % divisor) else: seq_length = original_seq_length else: 
seq_length = original_seq_length - + def split_tensor_for_cp(tensor, dim=-1): """ Split tensor along sequence dimension for Context Parallel. - + With causal masking, split into 2*CP chunks and assign alternating chunks to balance workload across CP ranks. For CP rank i: chunks [i, 2*CP-1-i] """ if tensor is None or cp_size <= 1: return tensor - + if dim < 0: dim = (dim + tensor.ndim) % tensor.ndim - + seq_len = tensor.shape[dim] - + # Reshape to [batch, 2*cp_size, seq_per_chunk, ...] view_shape = list(tensor.shape) - view_shape[dim:dim+1] = [2 * cp_size, seq_len // (2 * cp_size)] + view_shape[dim:dim + 1] = [2 * cp_size, seq_len // (2 * cp_size)] reshaped = tensor.view(*view_shape) - + # Select chunks [cp_rank, 2*cp_size-1-cp_rank] - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], - device='cpu', pin_memory=True).cuda(non_blocking=True) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], + device='cpu', + pin_memory=True).cuda(non_blocking=True) selected = reshaped.index_select(dim, index) - + # Reshape back: [batch, 2*seq_per_chunk, ...] 
out_shape = list(tensor.shape) out_shape[dim] = seq_len // cp_size return selected.reshape(*out_shape) - + # Define forward step function for Megatron # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) def forward_step_func(data_iterator, model): batch = next(data_iterator) - + # Move tensors to CUDA with non_blocking=True for async transfer - # This matches Swift's to_device(data, 'cuda', non_blocking=True) behavior def to_cuda_non_blocking(tensor): if tensor is None: return None if isinstance(tensor, torch.Tensor) and not tensor.is_cuda: return tensor.cuda(non_blocking=True) return tensor - + input_ids = to_cuda_non_blocking(batch.get('input_ids')) position_ids = to_cuda_non_blocking(batch.get('position_ids')) attention_mask = to_cuda_non_blocking(batch.get('attention_mask')) - batch_labels = to_cuda_non_blocking(batch.get('labels')) # Labels should be in each batch - + batch_labels = to_cuda_non_blocking( + batch.get('labels')) # Labels should be in each batch + # Pad sequence for Context Parallel compatibility # Megatron's RoPE requires seq_len % (2 * cp_size) == 0 if cp_size > 1 and input_ids is not None: @@ -664,17 +689,24 @@ def to_cuda_non_blocking(tensor): if seq_len % divisor != 0: pad_len = divisor - (seq_len % divisor) # Pad input_ids - input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=0) + input_ids = torch.nn.functional.pad(input_ids, + (0, pad_len), + value=0) # Pad labels if present if batch_labels is not None: - batch_labels = torch.nn.functional.pad(batch_labels, (0, pad_len), value=-100) + batch_labels = torch.nn.functional.pad(batch_labels, + (0, pad_len), + value=-100) # Pad attention_mask if present if attention_mask is not None: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_len), value=0) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad_len), value=0) # Pad position_ids if present if position_ids is not None: - position_ids = 
torch.nn.functional.pad(position_ids, (0, pad_len), value=0) - + position_ids = torch.nn.functional.pad(position_ids, + (0, pad_len), + value=0) + # Create position_ids if not provided if position_ids is None and input_ids is not None: position_ids = torch.arange( @@ -682,7 +714,7 @@ def to_cuda_non_blocking(tensor): device=input_ids.device, dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - + # Split tensors for Context Parallel # Each CP rank processes a portion of the sequence if cp_size > 1: @@ -690,7 +722,7 @@ def to_cuda_non_blocking(tensor): position_ids = split_tensor_for_cp(position_ids, dim=-1) attention_mask = split_tensor_for_cp(attention_mask, dim=-1) batch_labels = split_tensor_for_cp(batch_labels, dim=-1) - + # Forward pass with labels - Megatron will compute loss internally # This uses Megatron's compute_language_model_loss which properly handles # vocab parallel cross entropy @@ -700,59 +732,58 @@ def to_cuda_non_blocking(tensor): attention_mask=attention_mask, labels=batch_labels, # Pass labels to let Megatron compute loss ) - + # Megatron's compute_language_model_loss returns per-token loss [batch, seq] # We need to aggregate it with loss_mask and return 3 values for proper per-token normalization - # Swift uses 3-value return: (loss, num_tokens, loss_dict) for per-token loss mode def megatron_loss_func(labels_for_mask, cp_size, output_tensor): # output_tensor is per-token loss [batch, seq] # Create loss mask from labels (ignore -100) loss_mask = (labels_for_mask != -100) - + # Compute per-token losses losses = output_tensor.float() - - # Compute sum of losses and token count (same as Swift) - # Swift: loss = torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)]) + loss_sum = torch.sum(losses * loss_mask.float()) local_num_tokens = loss_mask.sum().to(torch.int) - + # For CP > 1, aggregate across CP ranks if cp_size > 1: # All-reduce loss sum and token count across CP ranks - loss_tensor = 
torch.cat([loss_sum.view(1), local_num_tokens.float().view(1)]) + loss_tensor = torch.cat( + [loss_sum.view(1), + local_num_tokens.float().view(1)]) torch.distributed.all_reduce( loss_tensor, op=torch.distributed.ReduceOp.SUM, - group=mpu.get_context_parallel_group() - ) + group=mpu.get_context_parallel_group()) loss_sum = loss_tensor[0] local_num_tokens = loss_tensor[1].to(torch.int) - - # Return 3 values for per-token loss mode (same as Swift): + # 1. loss (sum, will be divided by num_tokens by Megatron) # 2. local_num_tokens (for proper averaging) # 3. loss_dict for logging - reporting_loss = torch.cat([loss_sum.detach().view(1), local_num_tokens.float().view(1)]) - - return ( - loss_sum, - local_num_tokens, - {'lm loss': reporting_loss} - ) - - return output_tensor, partial(megatron_loss_func, batch_labels, cp_size) - + reporting_loss = torch.cat([ + loss_sum.detach().view(1), + local_num_tokens.float().view(1) + ]) + + return (loss_sum, local_num_tokens, { + 'lm loss': reporting_loss + }) + + return output_tensor, partial(megatron_loss_func, batch_labels, + cp_size) + # Get Megatron's forward-backward function # This automatically selects the right scheduler based on PP config: # - PP > 1: forward_backward_pipelining_without_interleaving (or with interleaving if VPP) # - PP = 1: forward_backward_no_pipelining forward_backward_func = get_forward_backward_func() - + # Create iterator over all microbatches # Megatron's scheduler will call next(data_iterator) num_microbatches times data_iter = iter(processed_batches) - + # Run forward-backward with Megatron's scheduler # Megatron handles all communication internally using proper process groups # With num_microbatches > 1, gradients are accumulated across microbatches @@ -765,23 +796,25 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): micro_batch_size=micro_batch_size, forward_only=False, ) - + # Extract loss from results (only last PP stage returns non-empty) # With 3-value loss_func return, 
each loss_dict contains 'lm loss': [loss_sum, num_tokens] # We aggregate across all microbatches using proper per-token averaging total_loss_sum = 0.0 total_num_tokens = 0 - + if losses: for loss_dict in losses: if isinstance(loss_dict, dict): # New format: 'lm loss' contains [loss_sum, num_tokens] if 'lm loss' in loss_dict: reporting = loss_dict['lm loss'] - if isinstance(reporting, torch.Tensor) and reporting.numel() == 2: + if isinstance(reporting, + torch.Tensor) and reporting.numel() == 2: total_loss_sum += reporting[0].item() total_num_tokens += int(reporting[1].item()) - elif isinstance(reporting, (list, tuple)) and len(reporting) == 2: + elif isinstance(reporting, + (list, tuple)) and len(reporting) == 2: total_loss_sum += float(reporting[0]) total_num_tokens += int(reporting[1]) # Legacy format: 'loss' contains average loss @@ -792,95 +825,100 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): else: total_loss_sum += float(loss_val) total_num_tokens += 1 # Fallback: treat as 1 sample - + # Compute average loss (per-token average across all microbatches) if total_num_tokens > 0: loss = total_loss_sum / total_num_tokens else: loss = total_loss_sum / max(num_microbatches, 1) - + # For PP > 1, broadcast loss from last PP stage to all ranks # Note: mpu is imported at module level, no need to reimport if mpu.get_pipeline_model_parallel_world_size() > 1: if isinstance(loss, torch.Tensor): loss_tensor = loss.detach().clone() else: - loss_tensor = torch.tensor(loss, dtype=torch.float32, device=torch.cuda.current_device()) - + loss_tensor = torch.tensor(loss, + dtype=torch.float32, + device=torch.cuda.current_device()) + # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1) src_rank = mpu.get_pipeline_model_parallel_last_rank() pp_group = mpu.get_pipeline_model_parallel_group() - - torch.distributed.broadcast( - loss_tensor, - src=src_rank, - group=pp_group - ) - + + torch.distributed.broadcast(loss_tensor, + 
src=src_rank, + group=pp_group) + loss = loss_tensor.item() - + optimizer_config.cur_step += 1 - + # Note: finalize_model_grads is called inside forward_backward_func # which already handles gradient synchronization across DP replicas. # No additional barrier is needed here - adding one would hurt performance. - + if isinstance(loss, torch.Tensor): return loss.detach().cpu().float().numpy() return float(loss) @remote_function(dispatch='all') - def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs): + def clip_grad_norm(self, + max_grad_norm: float = 1.0, + norm_type: int = 2, + **kwargs): """Clip gradient norm. - + Args: max_grad_norm: Maximum gradient norm. norm_type: Type of norm to use. **kwargs: Additional arguments. - + Returns: Total norm of gradients. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + # Check if using Megatron optimizer (handles clip_grad internally) - is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', False) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', + False) if is_megatron_opt: # Megatron optimizer handles gradient clipping in step() # Return the grad_norm from last step if available return getattr(optimizer_config, '_last_grad_norm', 0.0) - + parameters = self._get_trainable_parameters(adapter_name).values() - + return torch.nn.utils.clip_grad_norm_( - parameters, max_grad_norm, norm_type=norm_type - ).detach().cpu().numpy() + parameters, max_grad_norm, + norm_type=norm_type).detach().cpu().numpy() @remote_function(dispatch='all') def step(self, **kwargs): """Optimizer step. 
- + For DDP-wrapped models: - Gradients are synchronized automatically during backward via DDP - + For non-DDP models (e.g., PEFT/LoRA): - Gradients are NOT synchronized across DP ranks - Each DP replica trains independently with different data - This is a common pattern for PEFT training where the overhead of gradient averaging is not worth the benefit - + Note: Uses dispatch='all' to ensure all workers execute this method. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): return - + # For DDP-wrapped models, gradients are already synchronized during backward if self._is_model_ddp_wrapped(): # For Megatron DDP, ensure gradient buffers are finalized @@ -888,12 +926,13 @@ def step(self, **kwargs): self.model.finish_grad_sync() # For non-DDP models (e.g., PEFT), we skip gradient synchronization # Each DP replica trains independently, which is acceptable for PEFT - + optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer correctly before stepping' - + # Check if using Megatron optimizer (has different step() signature) - is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', False) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', + False) if is_megatron_opt: # Megatron optimizer step() returns (success, grad_norm, num_zeros) success, grad_norm, num_zeros = optimizer.step() @@ -902,156 +941,51 @@ def step(self, **kwargs): optimizer_config._last_step_success = success else: optimizer.step(**kwargs) - + def _is_model_ddp_wrapped(self) -> bool: """Check if model is wrapped with DDP. - + Returns: True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP). 
""" from torch.nn.parallel import DistributedDataParallel as TorchDDP return isinstance(self.model, (MegatronDDP, TorchDDP)) - + def _get_unwrapped_model(self) -> nn.Module: """Get the unwrapped model. - + Returns: The base model without DDP wrapper. """ return self.strategy.unwrap_model(self.model) - - @remote_function(dispatch='all') - def wrap_with_lora_ddp( - self, - adapter_name: str = _default_adapter_name, - overlap_grad_reduce: bool = True, - bucket_size: Optional[int] = None, - lora_param_patterns: Optional[set] = None, - **kwargs - ): - """ - Wrap the model with LoRA-aware DDP for efficient distributed training. - - This enables: - 1. Communication-computation overlap: Gradient all-reduce starts while - backward pass is still computing other gradients. - 2. Gradient bucketing: Small gradients are grouped for efficient communication. - 3. Async gradient reduction: Non-blocking communication operations. - - Should be called AFTER add_adapter_to_model() and BEFORE training starts. - - Args: - adapter_name: Name of the adapter (for multi-adapter scenarios). - overlap_grad_reduce: Enable communication-computation overlap. - Set to True for best performance (default). - bucket_size: Size of gradient buckets in number of elements. - None for automatic sizing based on LoRA parameter count. - lora_param_patterns: Set of patterns to identify LoRA parameters. - Default: {'lora_A', 'lora_B', 'lora_'} - **kwargs: Additional arguments passed to DDP config. - - use_distributed_optimizer: bool (default False for LoRA) - - grad_reduce_in_fp32: bool (default False) - - Returns: - self for method chaining. - - Example: - >>> model = MegatronModel(...) - >>> model.add_adapter_to_model('lora', lora_config) - >>> model.wrap_with_lora_ddp( - ... adapter_name='lora', - ... overlap_grad_reduce=True, - ... ) - >>> # Now training will use optimized DDP - >>> for batch in dataloader: - ... loss = model.forward_backward(inputs=batch) - ... model.step() - ... 
model.zero_grad() - """ - from twinkle.megatron.distributed import wrap_model_with_lora_ddp - from megatron.core.distributed import DistributedDataParallelConfig - - # Check if already wrapped - if self._is_model_ddp_wrapped(): - if mpu.get_data_parallel_rank() == 0: - print("Warning: Model is already DDP wrapped. Skipping wrap_with_lora_ddp().") - return self - - # Get the transformer config from the bridge initializer - transformer_config = getattr(self, '_transformer_config', None) - if transformer_config is None: - # Try to get from strategy - if hasattr(self.strategy, 'transformer_config'): - transformer_config = self.strategy.transformer_config - else: - raise ValueError( - "Cannot find TransformerConfig. " - "Make sure model is created via MegatronModel." - ) - - # Create DDP config - ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=overlap_grad_reduce, - use_distributed_optimizer=kwargs.get('use_distributed_optimizer', False), - grad_reduce_in_fp32=kwargs.get('grad_reduce_in_fp32', False), - bucket_size=bucket_size, - average_in_collective=kwargs.get('average_in_collective', False), - ) - - # Get tenant process group if multi-tenant - tenant_process_group = kwargs.get('tenant_process_group', None) - - # Wrap model - self.model = wrap_model_with_lora_ddp( - model=self.model, - config=transformer_config, - ddp_config=ddp_config, - lora_param_patterns=lora_param_patterns, - tenant_id=adapter_name, - tenant_process_group=tenant_process_group, - ) - - # CRITICAL: Update transformer_config.no_sync_func to use the DDP's no_sync - # This is needed for Megatron's forward_backward_func to properly control - # gradient synchronization during gradient accumulation - transformer_config.no_sync_func = self.model.no_sync - - # Also update finalize_model_grads_func to use the DDP's finish_grad_sync - # instead of the custom PEFT version - def finalize_model_grads_for_ddp(model_list, *args, **kwargs): - """Finalize gradients for DDP-wrapped model.""" - for 
model_chunk in model_list: - if hasattr(model_chunk, 'finish_grad_sync'): - model_chunk.finish_grad_sync() - transformer_config.finalize_model_grads_func = finalize_model_grads_for_ddp - - return self @remote_function(dispatch='all') def zero_grad(self, **kwargs): """Zero gradients. - + For DDP-wrapped models, also zeros the DDP gradient buffers. - + Note: For DDP-wrapped models, zero_grad_buffer() is always called because it's essential for the next training iteration. The do_grad_sync check only affects the optimizer.zero_grad() call. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + # For DDP-wrapped models, ALWAYS zero the gradient buffer # This is essential because Megatron's forward_backward_func uses # the buffer's state to track gradient accumulation - if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'): + if self._is_model_ddp_wrapped() and hasattr(self.model, + 'zero_grad_buffer'): self.model.zero_grad_buffer() - - if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): return - + optimizer = optimizer_config.optimizer if optimizer is not None: # Clear set_to_none for better compatibility @@ -1060,16 +994,17 @@ def zero_grad(self, **kwargs): @remote_function() def lr_step(self, **kwargs): """Learning rate scheduler step. - + Args: **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): return - + lr_scheduler = optimizer_config.lr_scheduler if lr_scheduler is not None: lr_scheduler.step(**kwargs) @@ -1077,22 +1012,22 @@ def lr_step(self, **kwargs): @remote_function(dispatch='all') def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): """Set loss function. - + NOTE: For MegatronModel, the loss is computed internally by Megatron's GPTModel when labels are passed. This method is kept for API compatibility but the provided loss_cls is NOT used during forward_backward. - + Megatron internally uses vocab_parallel_cross_entropy which correctly handles tensor parallelism. This design ensures Loss classes don't need to be aware of the training backend (Megatron vs Transformers). - + Args: loss_cls: Loss class or string name (not used for Megatron). **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(loss_cls, str): if hasattr(twinkle.loss, loss_cls): loss_cls = getattr(twinkle.loss, loss_cls) @@ -1102,9 +1037,10 @@ def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): optimizer_config.loss_instance = loss_cls() @remote_function(dispatch='all') - def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs): + def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], + **kwargs): """Set optimizer. - + Args: optimizer_cls: Optimizer class or string name. - Standard PyTorch optimizers: 'AdamW', 'Adam', 'SGD', etc. 
@@ -1115,30 +1051,31 @@ def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs): """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + # Check if requesting Megatron distributed optimizer - if optimizer_cls == 'MegatronDistributed' or kwargs.pop('use_megatron_optimizer', False): - optimizer_config.optimizer = self._create_megatron_optimizer(**kwargs) + if optimizer_cls == 'MegatronDistributed' or kwargs.pop( + 'use_megatron_optimizer', False): + optimizer_config.optimizer = self._create_megatron_optimizer( + **kwargs) optimizer_config.is_megatron_optimizer = True return - + if isinstance(optimizer_cls, str): if hasattr(torch.optim, optimizer_cls): optimizer_cls = getattr(torch.optim, optimizer_cls) else: optimizer_cls = Plugin.load_plugin(optimizer_cls, Optimizer) - + optimizer_config.optimizer = optimizer_cls( - self._get_trainable_parameters(adapter_name).values(), **kwargs - ) + self._get_trainable_parameters(adapter_name).values(), **kwargs) optimizer_config.is_megatron_optimizer = False - + def _create_megatron_optimizer(self, **kwargs): """Create Megatron distributed optimizer. - + This provides significant memory savings for large models by sharding optimizer states across DP replicas. - + Args: **kwargs: Optimizer configuration options. - lr: Learning rate (default: 1e-4) @@ -1147,16 +1084,17 @@ def _create_megatron_optimizer(self, **kwargs): - clip_grad: Gradient clipping threshold (default: 1.0) - bf16: Use bf16 training (default: True) - adam_beta1, adam_beta2, adam_eps: Adam parameters - + Returns: MegatronOptimizer instance. 
""" from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig - + # Build optimizer config lr = kwargs.get('lr', 1e-4) - use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True) - + use_distributed_optimizer = kwargs.get('use_distributed_optimizer', + True) + opt_config = OptimizerConfig( optimizer='adam', lr=lr, @@ -1171,20 +1109,21 @@ def _create_megatron_optimizer(self, **kwargs): overlap_param_gather=kwargs.get('overlap_param_gather', False), log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), ) - + # For PEFT models, we need to handle the case where model is not DDP-wrapped # We create a temporary wrapper to satisfy Megatron's optimizer requirements model_chunks = [self.model] - + # Check if model has ddp_config (required for distributed optimizer) if not hasattr(self.model, 'ddp_config') and use_distributed_optimizer: # For PEFT models without DDP, fall back to non-distributed optimizer # but still use Megatron's optimized implementation opt_config.use_distributed_optimizer = False if mpu.get_data_parallel_rank() == 0: - print("Note: Falling back to non-distributed optimizer for PEFT model. " - "For distributed optimizer, wrap model with MegatronDDP.") - + print( + 'Note: Falling back to non-distributed optimizer for PEFT model. 
' + 'For distributed optimizer, wrap model with MegatronDDP.') + try: optimizer = get_megatron_optimizer( config=opt_config, @@ -1194,23 +1133,31 @@ def _create_megatron_optimizer(self, **kwargs): except Exception as e: # Fallback to simple FP32 optimizer if Megatron optimizer fails if mpu.get_data_parallel_rank() == 0: - print(f"Warning: Failed to create Megatron optimizer ({e}), falling back to PyTorch AdamW") - + print( + f'Warning: Failed to create Megatron optimizer ({e}), falling back to PyTorch AdamW' + ) + params = [p for p in self.model.parameters() if p.requires_grad] - return torch.optim.AdamW(params, lr=lr, weight_decay=kwargs.get('weight_decay', 0.0)) + return torch.optim.AdamW(params, + lr=lr, + weight_decay=kwargs.get( + 'weight_decay', 0.0)) - def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) -> Dict[str, nn.Parameter]: + def _get_trainable_parameters( + self, + adapter_name: str = _default_adapter_name + ) -> Dict[str, nn.Parameter]: """Get trainable parameters. - + Args: adapter_name: Name of adapter. - + Returns: Dict mapping parameter names to parameters. """ is_default = adapter_name == _default_adapter_name pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.') - + params = {} model = self.strategy.unwrap_model(self.model) for name, param in model.named_parameters(): @@ -1219,22 +1166,24 @@ def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) - return params @remote_function(dispatch='all') - def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwargs): + def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], + **kwargs): """Set learning rate scheduler. - + Args: scheduler_cls: Scheduler class or string name. **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(scheduler_cls, str): if hasattr(torch.optim.lr_scheduler, scheduler_cls): - scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_cls) + scheduler_cls = getattr(torch.optim.lr_scheduler, + scheduler_cls) else: scheduler_cls = Plugin.load_plugin(scheduler_cls, LRScheduler) - + optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer before setting lr_scheduler' optimizer_config.lr_scheduler = scheduler_cls(optimizer, **kwargs) @@ -1242,57 +1191,58 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwarg @remote_function(dispatch='all', sync=True) def save(self, output_dir: str, **kwargs): """Save model checkpoint. - + Args: output_dir: Output directory. **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron' - + if save_format == 'hf': self._save_hf_format(output_dir, adapter_name) else: self._save_megatron_format(output_dir, adapter_name) - + self._save_tokenizer(output_dir, adapter_name) - + def _save_hf_format(self, output_dir: str, adapter_name: str): """Save in HuggingFace format using bridge adapter. 
- + For distributed training: - All PP ranks participate in export (each has different layers) - Only DP rank 0 actually writes to disk - Uses barrier for synchronization - + For LoRA training: - Saves in PEFT format (adapter_model.safetensors + adapter_config.json) """ from twinkle.megatron.model.bridge import TwinkleBridgeAdapter import os - + # Check if this is LoRA training (has adapter_name other than default) is_lora = adapter_name and adapter_name != '' is_peft_format = is_lora - + # Create output directory on rank 0 only try: from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 + dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized( + ) else 0 except (ImportError, AssertionError): dp_rank = 0 - + if dp_rank == 0: os.makedirs(output_dir, exist_ok=True) - + # Synchronize before saving if dist.is_initialized(): dist.barrier() - + # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(self.hf_config.vocab_size) \ if hasattr(self, '_pad_vocab_size') else None - + # Use TwinkleBridgeAdapter for weight conversion # All ranks participate - bridge handles which ranks write adapter = TwinkleBridgeAdapter( @@ -1300,42 +1250,47 @@ def _save_hf_format(self, output_dir: str, adapter_name: str): tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, ep_size=self.strategy.ep_size, - model_path=self._model_path if hasattr(self, '_model_path') else self.model_id, + model_path=self._model_path + if hasattr(self, '_model_path') else self.model_id, padded_vocab_size=padded_vocab_size, ) - + # Get the model (unwrap if DDP wrapped) model = self.strategy.unwrap_model(self.model) - + # Use bridge to save weights - adapter.save_weights([model], output_dir, is_peft_format=is_peft_format) - + adapter.save_weights([model], + output_dir, + is_peft_format=is_peft_format) + # Save config on rank 0 only if dp_rank == 0: self.hf_config.save_pretrained(output_dir) - + def _pad_vocab_size(self, 
vocab_size: int) -> int: """Pad vocab size for tensor parallelism.""" divisor = self.strategy.tp_size * 128 return ((vocab_size + divisor - 1) // divisor) * divisor - + def _save_megatron_format(self, output_dir: str, adapter_name: str): """Save in Megatron checkpoint format.""" import os os.makedirs(output_dir, exist_ok=True) - + model = self.strategy.unwrap_model(self.model) state_dict = self._get_trainable_parameters(adapter_name) - + # Convert to CPU cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()} - + # Save with rank info for distributed checkpointing rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - - def _save_tokenizer(self, output_dir: str, adapter_name: str = _default_adapter_name): + + def _save_tokenizer(self, + output_dir: str, + adapter_name: str = _default_adapter_name): """Save tokenizer.""" optimizer_config = self.optimizer_group.get(adapter_name) if optimizer_config and optimizer_config.template: @@ -1344,10 +1299,10 @@ def _save_tokenizer(self, output_dir: str, adapter_name: str = _default_adapter_ @remote_function(execute='first') def get_state_dict(self, **kwargs): """Get trainable state dict. - + Args: **kwargs: Additional arguments. - + Returns: State dict of trainable parameters. """ @@ -1355,24 +1310,24 @@ def get_state_dict(self, **kwargs): return self._get_trainable_parameters(adapter_name) _peft_patched = False - + @classmethod def _patch_peft_for_megatron(cls): """Patch PEFT's BaseTuner to handle Megatron's TransformerConfig. - + Megatron's TransformerConfig doesn't have a .get() method like HuggingFace configs. This patch handles the AttributeError that occurs when PEFT tries to check tie_word_embeddings. 
""" if cls._peft_patched: return - + from typing import List import torch.nn as nn from peft.tuners.tuners_utils import BaseTuner - + _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules - + def _get_tied_target_modules(self, model: nn.Module) -> List[str]: try: return _origin_get_tied_target_modules(self, model) @@ -1380,13 +1335,16 @@ def _get_tied_target_modules(self, model: nn.Module) -> List[str]: # Megatron's TransformerConfig doesn't have .get() method # Check share_embeddings_and_output_weights instead tied_target_modules = [] - if getattr(model, 'share_embeddings_and_output_weights', False): + if getattr(model, 'share_embeddings_and_output_weights', + False): for target_module in self.targeted_module_names: module_name = target_module.split('.')[-1] - if module_name in ['output_layer', 'embedding', 'word_embeddings']: + if module_name in [ + 'output_layer', 'embedding', 'word_embeddings' + ]: tied_target_modules.append(target_module) return tied_target_modules - + BaseTuner._get_tied_target_modules = _get_tied_target_modules cls._peft_patched = True @@ -1398,73 +1356,77 @@ def add_adapter_to_model( **kwargs, ): """Add LoRA adapter to model. - + Args: adapter_name: Name of the adapter. config_or_dir: LoRA config or path to saved adapter. **kwargs: Additional arguments. 
""" - from twinkle.megatron.utils import ( - prepare_lora_model, patch_deepcopy, get_target_modules, set_linear_is_expert - ) - + from twinkle.megatron.utils import (prepare_lora_model, patch_deepcopy, + get_target_modules, + set_linear_is_expert) + # Patch PEFT BaseTuner to handle Megatron's TransformerConfig # which doesn't have a .get() method like HuggingFace configs self._patch_peft_for_megatron() - + assert adapter_name, 'Use a non-empty adapter_name' - + model = self.strategy.unwrap_model(self.model) - + # Mark expert layers for MoE models set_linear_is_expert(model) - + if isinstance(config_or_dir, str): # Load from path config_or_dir = HubOperation.download_model(config_or_dir) from peft import PeftModel - model = PeftModel.from_pretrained( - model, config_or_dir, adapter_name=adapter_name, - is_trainable=kwargs.get('is_trainable', True) - ) + model = PeftModel.from_pretrained(model, + config_or_dir, + adapter_name=adapter_name, + is_trainable=kwargs.get( + 'is_trainable', True)) else: # Create from config from peft import LoraConfig, get_peft_model - + if not isinstance(config_or_dir, LoraConfig): # Convert dict to LoraConfig config_or_dir = LoraConfig(**config_or_dir) - + # Expand target_modules (e.g., 'all-linear' -> actual module names) if config_or_dir.target_modules: if isinstance(config_or_dir.target_modules, str): target_modules = [config_or_dir.target_modules] else: target_modules = list(config_or_dir.target_modules) - + expanded_modules = get_target_modules(model, target_modules) config_or_dir.target_modules = expanded_modules - + with patch_deepcopy(): - model = get_peft_model(model, config_or_dir, adapter_name=adapter_name) - + model = get_peft_model(model, + config_or_dir, + adapter_name=adapter_name) + # Update model reference if self._model_wrapped: if isinstance(self.model, MegatronDDP): self.model.module = model else: self.model = model - + # Add finish_grad_sync method for Megatron's finalize_model_grads compatibility # This is needed 
because Megatron's forward_backward_func calls finish_grad_sync # on model chunks, but PEFT models don't have this method by default if not hasattr(self.model, 'finish_grad_sync'): + def finish_grad_sync(): """Synchronize gradients across DP ranks for non-DDP models. - + This is a compatibility shim for Megatron's finalize_model_grads. For PEFT/LoRA models, we manually all-reduce only trainable (LoRA) gradients. - + Optimizations: 1. Only process gradients of trainable parameters (LoRA weights) 2. Skip if DP size is 1 (no synchronization needed) @@ -1473,122 +1435,134 @@ def finish_grad_sync(): dp_world_size = mpu.get_data_parallel_world_size() if dp_world_size <= 1: return # No sync needed for DP=1 - - dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + + dp_cp_group = mpu.get_data_parallel_group( + with_context_parallel=True) grads = [] - + # Only collect gradients from trainable parameters (LoRA weights) # This is much faster than iterating all parameters for param in self.model.parameters(): if param.requires_grad and param.grad is not None: grads.append(param.grad.data) - + if not grads: return # No gradients to sync - + # Coalesced all-reduce for efficiency from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=dp_cp_group) - + dist.all_reduce(coalesced, + op=dist.ReduceOp.AVG, + group=dp_cp_group) + # Copy back synchronized gradients - for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + for grad, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads)): grad.copy_(synced) - + self.model.finish_grad_sync = finish_grad_sync - + # Create optimizer group for adapter self.optimizer_group[adapter_name] = MegatronOptimizerGroup() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config_or_dir - 
self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get( - 'gradient_accumulation_steps', 1 - ) - + self.optimizer_group[ + adapter_name].gradient_accumulation_steps = kwargs.get( + 'gradient_accumulation_steps', 1) + # Copy settings from default default_config = self.optimizer_group.get(_default_adapter_name) if default_config: if default_config.template: - self.optimizer_group[adapter_name].template = default_config.template + self.optimizer_group[ + adapter_name].template = default_config.template if default_config.processor: - self.optimizer_group[adapter_name].processor = default_config.processor + self.optimizer_group[ + adapter_name].processor = default_config.processor if default_config.loss_instance: - self.optimizer_group[adapter_name].loss_instance = default_config.loss_instance + self.optimizer_group[ + adapter_name].loss_instance = default_config.loss_instance @remote_function(dispatch='all') - def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs): + def set_template(self, template_cls: Union[Type[template.Template], str], + **kwargs): """Set template for input encoding. - + Args: template_cls: Template class or string name. **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(template_cls, str): if hasattr(template, template_cls): template_cls = getattr(template, template_cls) else: - template_cls = Plugin.load_plugin(template_cls, template.Template) + template_cls = Plugin.load_plugin(template_cls, + template.Template) optimizer_config.template = template_cls(self.model_id, **kwargs) @remote_function(dispatch='all') - def set_processor(self, processor_cls: Union[Type[InputProcessor], str], **kwargs): + def set_processor(self, processor_cls: Union[Type[InputProcessor], str], + **kwargs): """Set input processor. - + Args: processor_cls: Processor class or string name. 
**kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(processor_cls, str): if hasattr(twinkle.processor, processor_cls): processor_cls = getattr(twinkle.processor, processor_cls) else: - processor_cls = Plugin.load_plugin(processor_cls, InputProcessor) - optimizer_config.processor = processor_cls(device_mesh=self.device_mesh, **kwargs) + processor_cls = Plugin.load_plugin(processor_cls, + InputProcessor) + optimizer_config.processor = processor_cls( + device_mesh=self.device_mesh, **kwargs) @remote_function(execute='first') def get_train_configs(self, **kwargs): """Get training configuration summary. - + Args: **kwargs: Additional arguments. - + Returns: Configuration summary string. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + expr = f'Backend: Megatron-Core\n' expr += f'TP size: {self.strategy.tp_size}\n' expr += f'PP size: {self.strategy.pp_size}\n' expr += f'CP size: {self.strategy.cp_size}\n' expr += f'EP size: {self.strategy.ep_size}\n' expr += f'Sequence Parallel: {self.strategy.sequence_parallel}\n' - + if optimizer_config.adapter_config is not None: config = optimizer_config.adapter_config.__dict__ - config = {key: str(value) for key, value in config.items() if value is not None} + config = { + key: str(value) + for key, value in config.items() if value is not None + } expr += f'Adapter config:\n{json.dumps(config, indent=2, ensure_ascii=False)}\n' - + if optimizer_config.optimizer: expr += f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n' expr += f'Learning rate: {optimizer_config.optimizer.defaults.get("lr", "N/A")}\n' if optimizer_config.lr_scheduler: expr += f'LR scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n' expr += f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n' - + return expr - - def 
__repr__(self): - return ( - f"MegatronModel(model_id='{self.model_id}', " - f"tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, " - f"cp={self.strategy.cp_size}, ep={self.strategy.ep_size})" - ) + def __repr__(self): + return (f"MegatronModel(model_id='{self.model_id}', " + f'tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, ' + f'cp={self.strategy.cp_size}, ep={self.strategy.ep_size})') diff --git a/src/twinkle/model/strategy/megatron.py b/src/twinkle/model/strategy/megatron.py index 1b721083..5f0ad751 100644 --- a/src/twinkle/model/strategy/megatron.py +++ b/src/twinkle/model/strategy/megatron.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from .base import TrainStrategy @@ -19,7 +19,8 @@ from megatron.core.distributed import DistributedDataParallel as MegatronDDP from packaging import version MEGATRON_AVAILABLE = True - mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') except ImportError: MEGATRON_AVAILABLE = False mcore_013 = False @@ -29,21 +30,19 @@ def check_megatron_available(): """Check if Megatron-Core is available.""" if not MEGATRON_AVAILABLE: raise ImportError( - "Megatron-Core is not installed. Please install it with: " - "pip install megatron-core" - ) + 'Megatron-Core is not installed. Please install it with: ' + 'pip install megatron-core') class MegatronStrategy(TrainStrategy): """Strategy for Megatron-Core based distributed training. - + Supports Tensor Parallel (TP), Pipeline Parallel (PP), Context Parallel (CP), Expert Parallel (EP), and Data Parallel (DP). - + This strategy integrates with twinkle's DeviceMesh to provide a unified interface for distributed training configuration. 
""" - def __init__( self, tensor_model_parallel_size: int = 1, @@ -60,7 +59,7 @@ def __init__( megatron_args: Optional[Dict[str, Any]] = None, ): """Initialize MegatronStrategy. - + Args: tensor_model_parallel_size: Degree of tensor model parallelism. pipeline_model_parallel_size: Degree of pipeline model parallelism. @@ -76,14 +75,18 @@ def __init__( megatron_args: Additional Megatron arguments. """ check_megatron_available() - + # If device_mesh is provided, extract parallel sizes from it if device_mesh is not None: - tensor_model_parallel_size = self._get_dim_from_mesh(device_mesh, 'tp', tensor_model_parallel_size) - pipeline_model_parallel_size = self._get_dim_from_mesh(device_mesh, 'pp', pipeline_model_parallel_size) - context_parallel_size = self._get_dim_from_mesh(device_mesh, 'cp', context_parallel_size) - expert_model_parallel_size = self._get_dim_from_mesh(device_mesh, 'ep', expert_model_parallel_size) - + tensor_model_parallel_size = self._get_dim_from_mesh( + device_mesh, 'tp', tensor_model_parallel_size) + pipeline_model_parallel_size = self._get_dim_from_mesh( + device_mesh, 'pp', pipeline_model_parallel_size) + context_parallel_size = self._get_dim_from_mesh( + device_mesh, 'cp', context_parallel_size) + expert_model_parallel_size = self._get_dim_from_mesh( + device_mesh, 'ep', expert_model_parallel_size) + self.tp_size = tensor_model_parallel_size self.pp_size = pipeline_model_parallel_size self.cp_size = context_parallel_size @@ -96,19 +99,20 @@ def __init__( self.params_dtype = params_dtype self.device_mesh = device_mesh self.megatron_args = megatron_args or {} - + self._initialized = False self._parallel_state = None - + @staticmethod - def _get_dim_from_mesh(device_mesh: 'DeviceMesh', dim_name: str, default: int) -> int: + def _get_dim_from_mesh(device_mesh: 'DeviceMesh', dim_name: str, + default: int) -> int: """Get dimension size from device mesh. - + Args: device_mesh: The device mesh. dim_name: Name of the dimension. 
default: Default value if dimension not found. - + Returns: Dimension size. """ @@ -128,14 +132,14 @@ def from_device_mesh( **kwargs, ) -> 'MegatronStrategy': """Create MegatronStrategy from twinkle DeviceMesh. - + Args: device_mesh: Twinkle DeviceMesh with dimension names like 'tp', 'pp', 'cp', 'ep', 'dp'. sequence_parallel: Enable sequence parallelism. use_distributed_optimizer: Use Megatron's distributed optimizer. mixed_precision: Mixed precision mode. **kwargs: Additional arguments. - + Returns: MegatronStrategy instance. """ @@ -149,29 +153,29 @@ def from_device_mesh( def initialize(self, **kwargs) -> None: """Initialize Megatron parallel state. - + This method handles both local (torchrun) and Ray modes: - - **Local mode**: + + **Local mode**: - torch.distributed is already initialized by torchrun - Just initialize mpu.initialize_model_parallel() - + **Ray mode**: - Read RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT from environment - Initialize torch.distributed with these values - Then initialize mpu.initialize_model_parallel() - + This allows the same MegatronModel code to work in both modes. 
""" if self._initialized: return - + import os from datetime import timedelta - + # Determine execution mode twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') - + # Initialize torch.distributed if not already done if not dist.is_initialized(): if twinkle_mode == 'ray': @@ -181,10 +185,10 @@ def initialize(self, **kwargs) -> None: master_addr = os.environ.get('MASTER_ADDR', 'localhost') master_port = os.environ.get('MASTER_PORT', '29500') local_rank = int(os.environ.get('LOCAL_RANK', '0')) - + # Set CUDA device before init_process_group torch.cuda.set_device(local_rank) - + # Initialize process group dist.init_process_group( backend='nccl', @@ -197,38 +201,37 @@ def initialize(self, **kwargs) -> None: # Local mode: torchrun should have set up distributed # If not, initialize with default settings dist.init_process_group(backend='nccl') - + world_size = dist.get_world_size() - + # Validate parallel configuration total_model_parallel = self.tp_size * self.pp_size * self.cp_size if world_size % total_model_parallel != 0: raise ValueError( - f"World size ({world_size}) must be divisible by " - f"tp_size * pp_size * cp_size ({total_model_parallel})" - ) - + f'World size ({world_size}) must be divisible by ' + f'tp_size * pp_size * cp_size ({total_model_parallel})') + # Initialize Megatron parallel state init_kwargs = { 'tensor_model_parallel_size': self.tp_size, 'pipeline_model_parallel_size': self.pp_size, 'context_parallel_size': self.cp_size, } - + if self.vp_size is not None: init_kwargs['virtual_pipeline_model_parallel_size'] = self.vp_size - + # Handle MoE parallelism if self.ep_size > 1: init_kwargs['expert_model_parallel_size'] = self.ep_size if mcore_013: init_kwargs['expert_tensor_parallel_size'] = self.etp_size - + parallel_state.initialize_model_parallel(**init_kwargs) - + self._parallel_state = parallel_state self._initialized = True - + # Set CUDA device (may be redundant in Ray mode, but safe) local_rank = dist.get_rank() % torch.cuda.device_count() 
torch.cuda.set_device(local_rank) @@ -337,7 +340,7 @@ def is_data_parallel_main_rank(self) -> bool: def get_params_dtype(self) -> torch.dtype: """Get parameter dtype based on configuration. - + Returns: PyTorch dtype for model parameters. """ @@ -348,7 +351,7 @@ def get_params_dtype(self) -> torch.dtype: 'bf16': torch.bfloat16, } return dtype_map.get(self.params_dtype, torch.bfloat16) - + if self.mixed_precision == 'bf16': return torch.bfloat16 elif self.mixed_precision == 'fp16': @@ -357,42 +360,47 @@ def get_params_dtype(self) -> torch.dtype: def _get_transformer_config(self, model: nn.Module): """Get TransformerConfig from model, handling PEFT wrappers. - + Args: model: The model (may be wrapped with PEFT). - + Returns: TransformerConfig if found, None otherwise. """ # Direct config attribute config = getattr(model, 'config', None) - if config is not None and hasattr(config, 'tensor_model_parallel_size'): + if config is not None and hasattr(config, + 'tensor_model_parallel_size'): return config - + # PEFT model: model.base_model.model.config if hasattr(model, 'base_model'): base = model.base_model if hasattr(base, 'model'): config = getattr(base.model, 'config', None) - if config is not None and hasattr(config, 'tensor_model_parallel_size'): + if config is not None and hasattr( + config, 'tensor_model_parallel_size'): return config # Try base.config config = getattr(base, 'config', None) - if config is not None and hasattr(config, 'tensor_model_parallel_size'): + if config is not None and hasattr(config, + 'tensor_model_parallel_size'): return config - + # Wrapped model: model.model.config if hasattr(model, 'model'): config = getattr(model.model, 'config', None) - if config is not None and hasattr(config, 'tensor_model_parallel_size'): + if config is not None and hasattr(config, + 'tensor_model_parallel_size'): return config - + # Recursive search through modules for name, module in model.named_modules(): config = getattr(module, 'config', None) - if config is 
not None and hasattr(config, 'tensor_model_parallel_size'): + if config is not None and hasattr(config, + 'tensor_model_parallel_size'): return config - + return None def wrap_model( @@ -402,60 +410,61 @@ def wrap_model( use_distributed_optimizer: bool = True, ) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]: """Wrap model with Megatron DDP for data parallelism. - + This method behaves differently based on twinkle's execution mode: - + **Local mode (torchrun)**: - Uses Megatron native DDP wrapping - All processes are synchronized by torchrun, so collective ops work - + **Ray mode**: - Currently skips DDP wrapping to avoid deadlocks - Ray's asynchronous actor model makes collective synchronization hard - Each DP replica trains independently - + **Transformers/Accelerate comparison**: - Accelerate's `prepare()` works in Ray because it's a local operation - Megatron DDP's `broadcast_params()` is a collective that needs sync - + Args: model: The Megatron model (already has TP/PP via TransformerConfig). optimizer: Optional optimizer. use_distributed_optimizer: Whether to use distributed optimizer. - + Returns: Tuple of (wrapped_model, optimizer). """ if not self._initialized: self.initialize() - + # Determine execution mode import os twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') - + # Check DP world size dp_group = self.dp_group dp_world_size = 1 if dp_group is not None: dp_world_size = dist.get_world_size(dp_group) - + if dp_world_size <= 1: # No DP needed (single GPU or TP-only) return model, optimizer - + if twinkle_mode == 'ray': # In Ray mode, skip DDP for now due to collective sync issues # TODO: Implement Ray-compatible DDP with barrier synchronization import warnings warnings.warn( - "Skipping Megatron DDP in Ray mode. Each DP replica trains independently. " - "For synchronized training, use torchrun (TWINKLE_MODE=local)." + 'Skipping Megatron DDP in Ray mode. Each DP replica trains independently. 
' + 'For synchronized training, use torchrun (TWINKLE_MODE=local).' ) return model, optimizer - + # Local mode (torchrun): Use Megatron native DDP - return self._wrap_with_megatron_ddp(model, optimizer, use_distributed_optimizer) - + return self._wrap_with_megatron_ddp(model, optimizer, + use_distributed_optimizer) + def _wrap_with_megatron_ddp( self, model: nn.Module, @@ -467,17 +476,16 @@ def _wrap_with_megatron_ddp( """ from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.transformer.module import Float16Module - + # Get TransformerConfig from model config = self._get_transformer_config(model) if config is None: import warnings warnings.warn( - "Could not find TransformerConfig. Skipping DDP wrapping. " - "Gradient sync will need to be done manually." - ) + 'Could not find TransformerConfig. Skipping DDP wrapping. ' + 'Gradient sync will need to be done manually.') return model, optimizer - + # Ensure model is on GPU try: model_device = next(model.parameters()).device @@ -486,28 +494,30 @@ def _wrap_with_megatron_ddp( model = model.to(f'cuda:{local_rank}') except StopIteration: pass # No parameters - + # Wrap with Float16Module for mixed precision (like Megatron's get_model) - if (config.fp16 or config.bf16) and not isinstance(model, Float16Module): + if (config.fp16 + or config.bf16) and not isinstance(model, Float16Module): # Check if the inner model (for PEFT) needs wrapping inner_model = model - if hasattr(model, 'base_model') and hasattr(model.base_model, 'model'): + if hasattr(model, 'base_model') and hasattr( + model.base_model, 'model'): inner_model = model.base_model.model - + # Only wrap if not already wrapped if not isinstance(inner_model, Float16Module): # For PEFT models, we can't easily wrap the inner model # Just proceed without Float16Module if not hasattr(model, 'base_model'): model = Float16Module(config, model) - + # Create DDP config ddp_config = DistributedDataParallelConfig( grad_reduce_in_fp32=True, 
overlap_grad_reduce=False, use_distributed_optimizer=use_distributed_optimizer, ) - + # Wrap with MegatronDDP # TODO: multi-tenant ddp try: @@ -516,34 +526,36 @@ def _wrap_with_megatron_ddp( ddp_config=ddp_config, module=model, ) - + # Broadcast params from data parallel src rank # In torchrun mode, all ranks enter here simultaneously, so this works wrapped_model.broadcast_params() - + return wrapped_model, optimizer - + except Exception as e: import warnings - warnings.warn(f"Failed to wrap with Megatron DDP: {e}. Using unwrapped model.") + warnings.warn( + f'Failed to wrap with Megatron DDP: {e}. Using unwrapped model.' + ) return model, optimizer def unwrap_model(self, model: nn.Module) -> nn.Module: """Unwrap the distributed model to get the base model. - + Args: model: The wrapped model. - + Returns: The unwrapped base model. """ if isinstance(model, MegatronDDP): return model.module - + from torch.nn.parallel import DistributedDataParallel as TorchDDP if isinstance(model, TorchDDP): return model.module - + return model def get_model_config( @@ -560,7 +572,7 @@ def get_model_config( **kwargs, ): """Create a Megatron TransformerConfig. - + Args: hidden_size: Hidden dimension size. num_attention_heads: Number of attention heads. @@ -572,12 +584,12 @@ def get_model_config( num_experts: Number of MoE experts. moe_router_topk: Top-k for MoE routing. **kwargs: Additional config arguments. - + Returns: Megatron TransformerConfig. """ from megatron.core.transformer import TransformerConfig - + config = TransformerConfig( num_layers=num_layers, hidden_size=hidden_size, @@ -595,80 +607,80 @@ def get_model_config( moe_router_topk=moe_router_topk, **kwargs, ) - + return config - + def sync_gradients(self, model: Optional[nn.Module] = None) -> None: """Synchronize gradients across data parallel group. - + For DDP-wrapped models, gradients are synchronized automatically. For non-DDP models (e.g., PEFT models), this performs manual all-reduce. 
- + Args: model: Optional model to sync gradients for. If None, only barrier. """ if not self._initialized: return - + dp_group = self.dp_group if dp_group is None: return - + dp_size = dist.get_world_size(dp_group) if dp_size <= 1: return - + if model is not None: # Manual gradient synchronization for non-DDP models (e.g., PEFT) self.all_reduce_gradients(model) else: # Just barrier for DDP models dist.barrier(dp_group) - + def all_reduce_gradients(self, model: nn.Module) -> None: """All-reduce gradients of trainable parameters across data parallel group. - + This is used for PEFT/LoRA models that are not wrapped with DDP. Gradients are averaged across all DP ranks. - + Args: model: The model whose gradients to synchronize. """ if not self._initialized: return - + dp_group = self.dp_group if dp_group is None: return - + dp_size = dist.get_world_size(dp_group) if dp_size <= 1: return - + # Collect gradients from trainable parameters grads = [] for param in model.parameters(): if param.requires_grad and param.grad is not None: grads.append(param.grad.data) - + if not grads: return - + # Flatten all gradients into a single tensor for efficient communication # This reduces the number of all-reduce operations flat_grads = torch.cat([g.contiguous().view(-1) for g in grads]) - + # All-reduce and average dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM, group=dp_group) flat_grads.div_(dp_size) - + # Unflatten back to original gradient tensors offset = 0 for grad in grads: numel = grad.numel() grad.copy_(flat_grads[offset:offset + numel].view_as(grad)) offset += numel - + def all_reduce( self, tensor: torch.Tensor, @@ -676,26 +688,26 @@ def all_reduce( group: Optional[dist.ProcessGroup] = None, ) -> torch.Tensor: """All-reduce tensor across specified group. - + Args: tensor: Input tensor. op: Reduce operation. group: Process group (defaults to data parallel group). - + Returns: Reduced tensor. 
""" if not self._initialized: return tensor - + if group is None: group = self.dp_group - + if group is not None: dist.all_reduce(tensor, op=op, group=group) - + return tensor - + def broadcast( self, tensor: torch.Tensor, @@ -703,29 +715,29 @@ def broadcast( group: Optional[dist.ProcessGroup] = None, ) -> torch.Tensor: """Broadcast tensor from source rank. - + Args: tensor: Input tensor. src: Source rank. group: Process group (defaults to data parallel group). - + Returns: Broadcasted tensor. """ if not self._initialized: return tensor - + if group is None: group = self.dp_group - + if group is not None: dist.broadcast(tensor, src=src, group=group) - + return tensor def get_parallel_info(self) -> Dict[str, Any]: """Get parallelism configuration information. - + Returns: Dict with parallel configuration details. """ @@ -746,10 +758,8 @@ def get_parallel_info(self) -> Dict[str, Any]: 'cp_rank': self.cp_rank, 'ep_rank': self.ep_rank, } - + def __repr__(self) -> str: - return ( - f"MegatronStrategy(tp={self.tp_size}, pp={self.pp_size}, " - f"cp={self.cp_size}, ep={self.ep_size}, dp={self.dp_size}, " - f"sequence_parallel={self.sequence_parallel})" - ) + return (f'MegatronStrategy(tp={self.tp_size}, pp={self.pp_size}, ' + f'cp={self.cp_size}, ep={self.ep_size}, dp={self.dp_size}, ' + f'sequence_parallel={self.sequence_parallel})') diff --git a/tests/megatron/test_multi_tenant_ddp.py b/tests/megatron/test_multi_tenant_ddp.py new file mode 100644 index 00000000..3056b3ee --- /dev/null +++ b/tests/megatron/test_multi_tenant_ddp.py @@ -0,0 +1,181 @@ +""" +Unit tests for Multi-Tenant LoRA DDP. + +Tests: +1. Tenant context (ContextVar) +2. Tenant manager lifecycle +3. 
Dynamic tenant add/remove +""" + +import threading +import unittest +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + + +class TestTenantContext(unittest.TestCase): + """Tests for tenant_context module.""" + + def setUp(self): + from twinkle.megatron.distributed.tenant_context import set_current_tenant + set_current_tenant(None) + + def test_get_set(self): + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, set_current_tenant + ) + + self.assertIsNone(get_current_tenant()) + set_current_tenant("a") + self.assertEqual(get_current_tenant(), "a") + set_current_tenant(None) + self.assertIsNone(get_current_tenant()) + + def test_scope(self): + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, tenant_scope + ) + + with tenant_scope("x"): + self.assertEqual(get_current_tenant(), "x") + with tenant_scope("y"): + self.assertEqual(get_current_tenant(), "y") + self.assertEqual(get_current_tenant(), "x") + self.assertIsNone(get_current_tenant()) + + def test_require_tenant(self): + from twinkle.megatron.distributed.tenant_context import ( + require_tenant, tenant_scope + ) + + with self.assertRaises(RuntimeError): + require_tenant() + + with tenant_scope("t"): + self.assertEqual(require_tenant(), "t") + + def test_thread_isolation(self): + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, set_current_tenant + ) + + results = {} + + def worker(tid): + set_current_tenant(tid) + import time + time.sleep(0.01) + results[tid] = get_current_tenant() + + threads = [threading.Thread(target=worker, args=(f"t{i}",)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + for i in range(5): + self.assertEqual(results[f"t{i}"], f"t{i}") + + def test_generate_id(self): + from twinkle.megatron.distributed.tenant_context import generate_tenant_id + + ids = [generate_tenant_id() for _ in range(100)] + self.assertEqual(len(ids), 
len(set(ids))) + + +class TestTenantManager(unittest.TestCase): + """Tests for TenantManager.""" + + def test_initialize_finalize(self): + from twinkle.megatron.distributed.tenant_manager import TenantManager + + model = nn.Linear(10, 10) + manager = TenantManager(model) + + # Mock PEFT + with patch('twinkle.megatron.distributed.tenant_manager.PEFT_AVAILABLE', False): + # Add fake lora param + lora_param = nn.Parameter(torch.randn(4, 10)) + lora_param.requires_grad = True + model.lora_A = nn.ParameterDict({'test': lora_param}) + + # Need to patch named_parameters + original_named_params = model.named_parameters + def mock_named_params(): + yield 'weight', model.weight + yield 'lora_A.test.lora_A', lora_param + model.named_parameters = mock_named_params + + tid = manager.initialize( + optimizer_kwargs={'lr': 1e-4}, + adapter_name='test', + ) + + self.assertTrue(manager.has(tid)) + self.assertIn(tid, manager.list()) + + state = manager.get(tid) + self.assertEqual(state.adapter_name, 'test') + + manager.finalize(tid) + self.assertFalse(manager.has(tid)) + + def test_callbacks(self): + from twinkle.megatron.distributed.tenant_manager import TenantManager + + model = nn.Linear(10, 10) + manager = TenantManager(model) + + added = [] + removed = [] + + manager.register_add_callback(lambda s: added.append(s.tenant_id)) + manager.register_remove_callback(lambda s: removed.append(s.tenant_id)) + + with patch('twinkle.megatron.distributed.tenant_manager.PEFT_AVAILABLE', False): + lora_param = nn.Parameter(torch.randn(4, 10)) + original_named_params = model.named_parameters + def mock_named_params(): + yield 'lora_A.test.lora_A', lora_param + model.named_parameters = mock_named_params + + tid = manager.initialize(adapter_name='test') + self.assertEqual(added, [tid]) + + manager.finalize(tid) + self.assertEqual(removed, [tid]) + + +class TestMultiTenantDDP(unittest.TestCase): + """Tests for MultiTenantLoRADDP.""" + + 
@patch('twinkle.megatron.distributed.multi_tenant_ddp.MEGATRON_AVAILABLE', False) + def test_requires_megatron(self): + from twinkle.megatron.distributed.multi_tenant_ddp import MultiTenantLoRADDP + + with self.assertRaises(ImportError): + MultiTenantLoRADDP( + config=MagicMock(), + ddp_config=MagicMock(), + module=nn.Linear(10, 10), + ) + + +class TestMegatronMultiAdapter(unittest.TestCase): + """Tests for MegatronMultiAdapter.""" + + def test_adapter_var(self): + from twinkle.megatron.model.multi_tenant_megatron import MegatronMultiAdapter + + MegatronMultiAdapter._patched = False + + self.assertIsNone(MegatronMultiAdapter.get_current_adapter_name()) + MegatronMultiAdapter.set_current_adapter_name("a") + self.assertEqual(MegatronMultiAdapter.get_current_adapter_name(), "a") + MegatronMultiAdapter.set_current_adapter_name(None) + + +if __name__ == "__main__": + unittest.main() From 61a7f0c0e95565fd2b4082258d1709e6ac661f63 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 14:33:41 +0800 Subject: [PATCH 09/22] wip --- .gitignore | 1 + .../megatron_multi_tenant/__init__.py | 5 + .../megatron/megatron_multi_tenant/client.py | 4 +- .../megatron/megatron_multi_tenant/server.py | 15 +- cookbook/megatron_multi_tenant/server.py | 239 ------ src/twinkle/megatron/distributed/__init__.py | 20 +- .../distributed/clock_cycle_scheduler.py | 598 ++++++++++++++ .../megatron/distributed/tenant_manager.py | 10 +- src/twinkle/model/megatron.py | 29 +- tests/megatron/test_multi_tenant_benchmark.py | 746 ++++++++++++++++++ tests/megatron/test_multi_tenant_ddp.py | 181 ----- tests/megatron/test_multi_tenant_modules.py | 641 +++++++++++++++ 12 files changed, 2054 insertions(+), 435 deletions(-) create mode 100644 cookbook/megatron/megatron_multi_tenant/__init__.py delete mode 100644 cookbook/megatron_multi_tenant/server.py create mode 100644 src/twinkle/megatron/distributed/clock_cycle_scheduler.py create mode 100644 tests/megatron/test_multi_tenant_benchmark.py delete mode 
100644 tests/megatron/test_multi_tenant_ddp.py create mode 100644 tests/megatron/test_multi_tenant_modules.py diff --git a/.gitignore b/.gitignore index de8aa916..a380c1a8 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ wheels/ /package /temp MANIFEST +.locks/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/cookbook/megatron/megatron_multi_tenant/__init__.py b/cookbook/megatron/megatron_multi_tenant/__init__.py new file mode 100644 index 00000000..bda8411f --- /dev/null +++ b/cookbook/megatron/megatron_multi_tenant/__init__.py @@ -0,0 +1,5 @@ +# Multi-Tenant Megatron LoRA Training Demo +# +# This directory contains demo code for multi-tenant training: +# - server.py: FastAPI server managing shared base model +# - client.py: Training client for remote training diff --git a/cookbook/megatron/megatron_multi_tenant/client.py b/cookbook/megatron/megatron_multi_tenant/client.py index 9a729e5f..3772a414 100644 --- a/cookbook/megatron/megatron_multi_tenant/client.py +++ b/cookbook/megatron/megatron_multi_tenant/client.py @@ -2,7 +2,9 @@ Multi-Tenant Megatron LoRA Training - Client Example. Simple training loop using remote multi-tenant server. -Inspired by tinker-cookbook's minimal training scripts. + +Usage: + python client.py --server-url http://localhost:8080 """ import logging diff --git a/cookbook/megatron/megatron_multi_tenant/server.py b/cookbook/megatron/megatron_multi_tenant/server.py index 6cfa63e6..d2ca3c3a 100644 --- a/cookbook/megatron/megatron_multi_tenant/server.py +++ b/cookbook/megatron/megatron_multi_tenant/server.py @@ -2,6 +2,9 @@ Multi-Tenant Megatron LoRA Training - Server. Creates a shared base model and provides APIs for multi-tenant training. 
+ +Usage: + python server.py --model-id Qwen/Qwen2.5-7B --tp 2 --port 8080 """ import argparse @@ -149,8 +152,9 @@ def lr_step(self, tenant_id: str): self._heartbeat(tenant_id) self.model.lr_step(tenant_id) - def list_tenants(self) -> List[str]: - return self.model.list_tenants() + def tenant_count(self) -> int: + """Get number of active tenants (does not expose tenant IDs for privacy).""" + return self.model.tenant_count() # ============ FastAPI App ============ @@ -205,9 +209,10 @@ def lr_step(request: Request): server.lr_step(get_tenant(request)) return TenantResponse() - @app.get("/tenants") - def tenants(): - return {"tenants": server.list_tenants()} + @app.get("/stats") + def stats(): + """Server statistics (does not expose tenant IDs for privacy).""" + return {"tenant_count": server.tenant_count()} @app.get("/health") def health(): diff --git a/cookbook/megatron_multi_tenant/server.py b/cookbook/megatron_multi_tenant/server.py deleted file mode 100644 index 45ecd925..00000000 --- a/cookbook/megatron_multi_tenant/server.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -Multi-Tenant Megatron LoRA Training - Server. - -Creates a shared base model and provides APIs for multi-tenant training. 
-""" - -import argparse -import logging -import threading -import time -from typing import Any, Dict, List, Optional - -import torch -from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -# ============ Request/Response Models ============ - -class InitializeRequest(BaseModel): - lora_config: Optional[Dict[str, Any]] = None - optimizer_cls: str = "AdamW" - optimizer_kwargs: Optional[Dict[str, Any]] = None - gradient_accumulation_steps: int = 1 - max_grad_norm: float = 1.0 - -class InputsRequest(BaseModel): - inputs: Any - -class TenantResponse(BaseModel): - status: str = "ok" - tenant_id: Optional[str] = None - data: Optional[Any] = None - - -# ============ Server ============ - -class MultiTenantServer: - """Server managing multi-tenant Megatron model.""" - - TIMEOUT = 60 * 30 # 30 min heartbeat timeout - - def __init__(self, model_id: str, tp_size: int = 1): - self.model_id = model_id - self.tp_size = tp_size - self.model = None - self._heartbeats: Dict[str, float] = {} - self._lock = threading.Lock() - - def setup(self): - """Initialize model.""" - from twinkle.megatron.model import ( - MultiTenantMegatronModel, - initialize_megatron_model, - ) - - logger.info(f"Loading model: {self.model_id}") - base_model, config = initialize_megatron_model( - model_id=self.model_id, - tensor_parallel_size=self.tp_size, - ) - - # Freeze base model - for p in base_model.parameters(): - p.requires_grad = False - - self.model = MultiTenantMegatronModel(base_model, config) - logger.info("Server ready") - - # Start heartbeat monitor - threading.Thread(target=self._monitor, daemon=True).start() - - def _monitor(self): - """Cleanup inactive tenants.""" - while True: - time.sleep(60) - now = time.time() - with self._lock: - expired = [t for t, ts in self._heartbeats.items() if now - ts > self.TIMEOUT] - for tid in expired: - logger.warning(f"Tenant {tid} timed out") - try: - self.finalize(tid) - except: - pass 
- - def _heartbeat(self, tenant_id: str): - with self._lock: - self._heartbeats[tenant_id] = time.time() - - def initialize(self, request: InitializeRequest) -> str: - """Initialize tenant.""" - from peft import LoraConfig - - lora_config = None - if request.lora_config: - lora_config = LoraConfig(**request.lora_config) - - opt_map = {"AdamW": torch.optim.AdamW, "Adam": torch.optim.Adam} - opt_cls = opt_map.get(request.optimizer_cls, torch.optim.AdamW) - - tenant_id = self.model.initialize( - lora_config=lora_config, - optimizer_cls=opt_cls, - optimizer_kwargs=request.optimizer_kwargs, - gradient_accumulation_steps=request.gradient_accumulation_steps, - max_grad_norm=request.max_grad_norm, - ) - - self._heartbeat(tenant_id) - return tenant_id - - def finalize(self, tenant_id: str): - """Finalize tenant.""" - self.model.finalize(tenant_id) - with self._lock: - self._heartbeats.pop(tenant_id, None) - - def forward_backward(self, tenant_id: str, inputs: Any) -> Dict: - """Forward + backward.""" - self._heartbeat(tenant_id) - - with self.model.scope(tenant_id): - output = self.model(inputs) - # Compute loss (simplified - real impl would depend on task) - loss = output.mean() if isinstance(output, torch.Tensor) else torch.tensor(0.0) - self.model.backward(loss) - return {"loss": loss.item()} - - def finish_grad_sync(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.finish_grad_sync(tenant_id) - - def clip_grad_norm(self, tenant_id: str) -> float: - self._heartbeat(tenant_id) - return self.model.clip_grad_norm(tenant_id=tenant_id).item() - - def step(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.step(tenant_id) - - def zero_grad(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.zero_grad(tenant_id) - - def lr_step(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.lr_step(tenant_id) - - def tenant_count(self) -> int: - """Get number of active tenants (does not expose tenant IDs).""" - return 
self.model.tenant_count() - - -# ============ FastAPI App ============ - -def create_app(server: MultiTenantServer) -> FastAPI: - """Create FastAPI app.""" - app = FastAPI(title="Multi-Tenant Megatron Server") - - def get_tenant(request: Request) -> str: - tid = request.headers.get("X-Tenant-ID") - if not tid: - raise HTTPException(400, "Missing X-Tenant-ID") - return tid - - @app.post("/initialize", response_model=TenantResponse) - def initialize(body: InitializeRequest): - tid = server.initialize(body) - return TenantResponse(tenant_id=tid) - - @app.post("/finalize", response_model=TenantResponse) - def finalize(request: Request): - server.finalize(get_tenant(request)) - return TenantResponse() - - @app.post("/forward_backward", response_model=TenantResponse) - def forward_backward(request: Request, body: InputsRequest): - data = server.forward_backward(get_tenant(request), body.inputs) - return TenantResponse(data=data) - - @app.post("/finish_grad_sync", response_model=TenantResponse) - def finish_grad_sync(request: Request): - server.finish_grad_sync(get_tenant(request)) - return TenantResponse() - - @app.post("/clip_grad_norm", response_model=TenantResponse) - def clip_grad_norm(request: Request): - norm = server.clip_grad_norm(get_tenant(request)) - return TenantResponse(data=norm) - - @app.post("/step", response_model=TenantResponse) - def step(request: Request): - server.step(get_tenant(request)) - return TenantResponse() - - @app.post("/zero_grad", response_model=TenantResponse) - def zero_grad(request: Request): - server.zero_grad(get_tenant(request)) - return TenantResponse() - - @app.post("/lr_step", response_model=TenantResponse) - def lr_step(request: Request): - server.lr_step(get_tenant(request)) - return TenantResponse() - - @app.get("/stats") - def stats(): - """Server statistics (does not expose tenant IDs for privacy).""" - return {"tenant_count": server.tenant_count()} - - @app.get("/health") - def health(): - return {"status": "healthy"} - - 
return app - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-id", required=True) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--host", default="0.0.0.0") - parser.add_argument("--port", type=int, default=8080) - args = parser.parse_args() - - logging.basicConfig(level=logging.INFO) - - server = MultiTenantServer(args.model_id, args.tp) - server.setup() - - import uvicorn - uvicorn.run(create_app(server), host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/src/twinkle/megatron/distributed/__init__.py b/src/twinkle/megatron/distributed/__init__.py index a1defbab..364bd702 100644 --- a/src/twinkle/megatron/distributed/__init__.py +++ b/src/twinkle/megatron/distributed/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) twinkle authors. All rights reserved. """ +[WIP] Distributed training utilities for multi-tenant Megatron LoRA. Core components: @@ -13,6 +14,15 @@ get_current_tenant, require_tenant, set_current_tenant, tenant_scope) from .tenant_manager import TenantManager, TenantState +from .clock_cycle_scheduler import ( + ClockCycleScheduler, + ClockCycleTrainingClient, + CycleStats, + RequestType, + TrainingRequest, + ModelInterfaceError, + validate_model_interface, +) __all__ = [ # Context @@ -25,7 +35,15 @@ # Manager 'TenantManager', 'TenantState', - # DDP + # DDP (Twinkle mode) 'MultiTenantLoRADDP', 'TenantDDPState', + # Clock Cycle Scheduler + 'ClockCycleScheduler', + 'ClockCycleTrainingClient', + 'CycleStats', + 'RequestType', + 'TrainingRequest', + 'ModelInterfaceError', + 'validate_model_interface', ] diff --git a/src/twinkle/megatron/distributed/clock_cycle_scheduler.py b/src/twinkle/megatron/distributed/clock_cycle_scheduler.py new file mode 100644 index 00000000..b1ae11b9 --- /dev/null +++ b/src/twinkle/megatron/distributed/clock_cycle_scheduler.py @@ -0,0 +1,598 @@ +# Copyright (c) twinkle authors. All rights reserved. 
+""" +Clock Cycle Scheduler for multi-tenant training. + +This module implements a time-sharing scheduler that batches requests +from multiple tenants into fixed clock cycles. + +## Key Concepts + +- Clock Cycle: Fixed time interval where all pending requests are processed +- Request Queue: Collects requests between cycles +- Batched Grad Sync: One communication round for all tenants (efficient) +- Gradient Isolation: Each tenant has separate LoRA params, no gradient overwrite +""" + +import logging +import threading +import time +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple +from concurrent.futures import Future + +import torch +import torch.nn as nn +import torch.distributed as dist + +logger = logging.getLogger(__name__) + + +# ============ Request Types ============ + +class RequestType(Enum): + """Type of training request.""" + FORWARD_BACKWARD = "forward_backward" + OPTIM_STEP = "optim_step" + ZERO_GRAD = "zero_grad" + + +@dataclass +class TrainingRequest: + """A training request from a tenant.""" + tenant_id: str + request_type: RequestType + inputs: Any = None + labels: Any = None + kwargs: Dict[str, Any] = field(default_factory=dict) + future: Optional[Future] = None + submitted_at: float = field(default_factory=time.time) + + +# ============ Cycle Statistics ============ + +@dataclass +class CycleStats: + """Statistics for a clock cycle.""" + cycle_id: int + start_time: float + end_time: float + num_tenants: int + num_requests: int + total_samples: int + forward_time: float + backward_time: float + grad_sync_time: float + optim_step_time: float + + @property + def duration(self) -> float: + return self.end_time - self.start_time + + @property + def gpu_active_time(self) -> float: + return self.forward_time + self.backward_time + self.optim_step_time + + @property + def gpu_utilization(self) -> float: + if self.duration > 0: + return 
self.gpu_active_time / self.duration + return 0.0 + + @property + def samples_per_second(self) -> float: + if self.duration > 0: + return self.total_samples / self.duration + return 0.0 + + +# ============ Model Interface Requirements ============ + +class ModelInterfaceError(Exception): + """Raised when model doesn't implement required interface.""" + pass + + +def validate_model_interface(model: nn.Module) -> None: + """ + Validate that model implements the required interface. + + Required methods: + - scope(tenant_id) -> context manager + - zero_grad(tenant_id) -> None + - step(tenant_id) -> None + - __call__(inputs) -> output (forward) + + Optional methods: + - clip_grad_norm(tenant_id, max_norm) -> None + - finish_grad_sync(tenant_id) -> None + - finish_grad_sync_batched(tenant_ids) -> None + """ + required = ['scope', 'zero_grad', 'step'] + missing = [m for m in required if not hasattr(model, m)] + + if missing: + raise ModelInterfaceError( + f"Model must implement: {required}. Missing: {missing}" + ) + + +# ============ Gradient Synchronization ============ + +class GradientSynchronizer: + """ + Handles gradient synchronization for multiple tenants. + + For distributed training, this batches gradient communication + to reduce the number of NCCL calls. + """ + + def __init__(self, model: nn.Module): + self.model = model + + def sync_individual(self, tenant_id: str) -> float: + """Synchronize gradients for a single tenant.""" + if not dist.is_initialized(): + return 0.0 + + t0 = time.time() + + if hasattr(self.model, 'finish_grad_sync'): + self.model.finish_grad_sync(tenant_id) + + return time.time() - t0 + + def sync_batched(self, tenant_ids: List[str]) -> float: + """ + Synchronize gradients for multiple tenants. + + Uses batched sync if model supports it, otherwise falls back + to individual sync. 
+ """ + if not tenant_ids: + return 0.0 + + t0 = time.time() + + if hasattr(self.model, 'finish_grad_sync_batched'): + # Optimized: one call for all tenants + self.model.finish_grad_sync_batched(tenant_ids) + elif dist.is_initialized(): + # Fallback: sync each tenant individually + for tenant_id in tenant_ids: + self.sync_individual(tenant_id) + + return time.time() - t0 + + +# ============ Clock Cycle Scheduler ============ + +class ClockCycleScheduler: + """ + Clock cycle scheduler for multi-tenant training. + + Collects requests from multiple tenants and executes them in batched + clock cycles. While computation is per-tenant serial (due to LoRA + architecture), communication is batched for efficiency. + + ## Benefits + + 1. **Unified Scheduling**: All tenants processed in fixed cycles + 2. **Batched Communication**: One grad sync round for all tenants + 3. **Fair Scheduling**: All pending requests processed together + 4. **Predictable Latency**: Fixed cycle interval + + ## Usage + + ```python + scheduler = ClockCycleScheduler(model, cycle_interval_ms=100) + scheduler.start() + + # From multiple clients + future1 = scheduler.submit_forward_backward('tenant_a', inputs_a) + future2 = scheduler.submit_forward_backward('tenant_b', inputs_b) + + result1 = future1.result() + result2 = future2.result() + + scheduler.stop() + ``` + """ + + def __init__( + self, + model: nn.Module, + cycle_interval_ms: float = 100.0, + loss_fn: Optional[Callable] = None, + ): + """ + Initialize the scheduler. + + Args: + model: The multi-tenant model. Must implement: + - scope(tenant_id) -> context manager + - zero_grad(tenant_id) -> None + - step(tenant_id) -> None + - __call__(inputs) -> output + + cycle_interval_ms: Clock cycle interval in milliseconds. + loss_fn: Loss function (output, labels) -> loss. + Default: output.mean() + + Raises: + ModelInterfaceError: If model doesn't implement required methods. 
+ """ + # Validate model interface + validate_model_interface(model) + + self.model = model + self.cycle_interval = cycle_interval_ms / 1000.0 + self.loss_fn = loss_fn or self._default_loss_fn + + # Gradient synchronizer + self._grad_sync = GradientSynchronizer(model) + + # Request queue (thread-safe) + self._queue_lock = threading.Lock() + self._request_queue: Dict[str, List[TrainingRequest]] = defaultdict(list) + + # Cycle management + self._running = False + self._cycle_thread: Optional[threading.Thread] = None + self._current_cycle_id = 0 + + # Statistics + self._stats: List[CycleStats] = [] + self._stats_lock = threading.Lock() + + def _default_loss_fn(self, output: torch.Tensor, labels: Any) -> torch.Tensor: + """Default loss function (mean of output).""" + if isinstance(output, torch.Tensor): + return output.mean() + raise ValueError(f"Cannot compute loss on {type(output)}, provide loss_fn") + + def start(self): + """Start the clock cycle loop.""" + if self._running: + return + + self._running = True + self._cycle_thread = threading.Thread( + target=self._cycle_loop, + name="ClockCycleLoop", + daemon=True, + ) + self._cycle_thread.start() + logger.info(f"Clock cycle scheduler started (interval={self.cycle_interval*1000:.0f}ms)") + + def stop(self): + """Stop the clock cycle loop.""" + self._running = False + if self._cycle_thread: + self._cycle_thread.join(timeout=5.0) + self._cycle_thread = None + logger.info("Clock cycle scheduler stopped") + + def submit_forward_backward( + self, + tenant_id: str, + inputs: Any, + labels: Any = None, + **kwargs, + ) -> Future: + """ + Submit a forward-backward request. + + Returns immediately with a Future containing the result. 
+ """ + future = Future() + request = TrainingRequest( + tenant_id=tenant_id, + request_type=RequestType.FORWARD_BACKWARD, + inputs=inputs, + labels=labels, + kwargs=kwargs, + future=future, + ) + + with self._queue_lock: + self._request_queue[tenant_id].append(request) + + return future + + def submit_optim_step(self, tenant_id: str, **kwargs) -> Future: + """Submit an optimizer step request.""" + future = Future() + request = TrainingRequest( + tenant_id=tenant_id, + request_type=RequestType.OPTIM_STEP, + kwargs=kwargs, + future=future, + ) + + with self._queue_lock: + self._request_queue[tenant_id].append(request) + + return future + + def submit_zero_grad(self, tenant_id: str) -> Future: + """Submit a zero_grad request.""" + future = Future() + request = TrainingRequest( + tenant_id=tenant_id, + request_type=RequestType.ZERO_GRAD, + future=future, + ) + + with self._queue_lock: + self._request_queue[tenant_id].append(request) + + return future + + def _cycle_loop(self): + """Main clock cycle loop.""" + while self._running: + cycle_start = time.time() + + # Collect pending requests (deep copy for thread safety) + with self._queue_lock: + pending = { + tenant_id: list(reqs) + for tenant_id, reqs in self._request_queue.items() + if reqs + } + self._request_queue.clear() + + # Execute cycle if there are requests + if pending: + self._execute_cycle(pending) + + # Wait for next cycle + elapsed = time.time() - cycle_start + sleep_time = max(0, self.cycle_interval - elapsed) + if sleep_time > 0: + time.sleep(sleep_time) + + def _execute_cycle(self, requests: Dict[str, List[TrainingRequest]]): + """ + Execute one clock cycle. + + Phases: + 1. Forward-backward for each tenant (serial due to LoRA architecture) + 2. Batched gradient synchronization (efficient) + 3. Optimizer step for each tenant + + Note: Forward-backward is serial per tenant because each tenant's + LoRA weights are embedded in every layer. 
There's no way to "merge" + computation across tenants for Transformer models. + + However, gradient sync is batched, reducing communication overhead. + """ + cycle_start = time.time() + self._current_cycle_id += 1 + cycle_id = self._current_cycle_id + + # Group requests by type + fwd_bwd_reqs: Dict[str, TrainingRequest] = {} + optim_step_reqs: Dict[str, TrainingRequest] = {} + zero_grad_reqs: Dict[str, TrainingRequest] = {} + + for tenant_id, reqs in requests.items(): + for req in reqs: + if req.request_type == RequestType.FORWARD_BACKWARD: + fwd_bwd_reqs[tenant_id] = req + elif req.request_type == RequestType.OPTIM_STEP: + optim_step_reqs[tenant_id] = req + elif req.request_type == RequestType.ZERO_GRAD: + zero_grad_reqs[tenant_id] = req + + num_tenants = len(set(fwd_bwd_reqs.keys()) | set(optim_step_reqs.keys())) + num_requests = sum(len(reqs) for reqs in requests.values()) + + logger.debug(f"Cycle {cycle_id}: {num_tenants} tenants, {num_requests} requests") + + # Tracking + successful_tenants = [] + failed_tenants = [] + total_samples = 0 + forward_time = 0.0 + backward_time = 0.0 + grad_sync_time = 0.0 + optim_step_time = 0.0 + + try: + # ============ PHASE 1: Forward-Backward (per tenant) ============ + # Each tenant's forward-backward is independent because: + # 1. Each tenant has separate LoRA parameters + # 2. Gradients accumulate to each tenant's own LoRA params + # 3. 
No gradient overwrite between tenants + + for tenant_id, req in fwd_bwd_reqs.items(): + try: + inputs = req.inputs + labels = req.labels + + # Get batch size for stats + if isinstance(inputs, torch.Tensor): + batch_size = inputs.size(0) + elif isinstance(inputs, dict) and 'input_ids' in inputs: + batch_size = inputs['input_ids'].size(0) + else: + batch_size = 1 + + total_samples += batch_size + + with self.model.scope(tenant_id): + # Forward + t0 = time.time() + output = self.model(inputs) + forward_time += time.time() - t0 + + # Loss + loss = self.loss_fn(output, labels) + + # Backward + # Note: Each tenant's LoRA params are separate, + # so loss.backward() accumulates gradients to + # this tenant's params only - no overwrite + t0 = time.time() + loss.backward() + backward_time += time.time() - t0 + + # Record result + result = { + 'loss': loss.item() if hasattr(loss, 'item') else float(loss), + 'cycle_id': cycle_id, + 'batch_size': batch_size, + } + req.future.set_result(result) + successful_tenants.append(tenant_id) + + except Exception as e: + logger.error(f"Tenant {tenant_id} forward-backward failed: {e}") + failed_tenants.append(tenant_id) + req.future.set_exception(e) + + # Clean up failed tenant's gradient state + try: + self.model.zero_grad(tenant_id) + except Exception: + pass + + # ============ PHASE 2: Batched Gradient Sync ============ + # This is where we get efficiency: one sync round for all tenants + if successful_tenants: + grad_sync_time = self._grad_sync.sync_batched(successful_tenants) + + # ============ PHASE 3: Optimizer Step ============ + for tenant_id, req in optim_step_reqs.items(): + try: + t0 = time.time() + + # Clip gradients (optional) + if hasattr(self.model, 'clip_grad_norm'): + self.model.clip_grad_norm(tenant_id=tenant_id) + + # Optimizer step + self.model.step(tenant_id) + + # Zero grad after step + self.model.zero_grad(tenant_id) + + optim_step_time += time.time() - t0 + req.future.set_result({'cycle_id': cycle_id}) + + except 
Exception as e: + logger.error(f"Tenant {tenant_id} optimizer step failed: {e}") + req.future.set_exception(e) + + # ============ PHASE 4: Standalone zero_grad ============ + for tenant_id, req in zero_grad_reqs.items(): + if tenant_id not in optim_step_reqs: + try: + self.model.zero_grad(tenant_id) + req.future.set_result(None) + except Exception as e: + req.future.set_exception(e) + + except Exception as e: + logger.exception(f"Cycle {cycle_id} failed: {e}") + for reqs in requests.values(): + for req in reqs: + if not req.future.done(): + req.future.set_exception(e) + + # Record stats + cycle_end = time.time() + stats = CycleStats( + cycle_id=cycle_id, + start_time=cycle_start, + end_time=cycle_end, + num_tenants=num_tenants, + num_requests=num_requests, + total_samples=total_samples, + forward_time=forward_time, + backward_time=backward_time, + grad_sync_time=grad_sync_time, + optim_step_time=optim_step_time, + ) + + with self._stats_lock: + self._stats.append(stats) + + logger.debug( + f"Cycle {cycle_id} completed: duration={stats.duration*1000:.1f}ms, " + f"tenants={num_tenants}, samples={total_samples}" + ) + + def get_stats(self) -> List[CycleStats]: + """Get all cycle statistics.""" + with self._stats_lock: + return list(self._stats) + + def get_summary_stats(self) -> Dict[str, float]: + """Get summary statistics.""" + with self._stats_lock: + if not self._stats: + return {} + + total_cycles = len(self._stats) + total_duration = sum(s.duration for s in self._stats) + total_samples = sum(s.total_samples for s in self._stats) + total_forward = sum(s.forward_time for s in self._stats) + total_backward = sum(s.backward_time for s in self._stats) + total_sync = sum(s.grad_sync_time for s in self._stats) + total_optim = sum(s.optim_step_time for s in self._stats) + total_gpu_time = total_forward + total_backward + total_optim + + return { + 'total_cycles': total_cycles, + 'total_duration': total_duration, + 'total_samples': total_samples, + 'total_gpu_time': 
total_gpu_time, + 'total_comm_time': total_sync, + 'avg_cycle_duration': total_duration / total_cycles, + 'gpu_utilization': total_gpu_time / total_duration if total_duration > 0 else 0, + 'throughput_samples_per_sec': total_samples / total_duration if total_duration > 0 else 0, + } + + +# ============ Training Client ============ + +class ClockCycleTrainingClient: + """ + Client for clock cycle scheduler. + + Provides a simple API for submitting training requests. + """ + + def __init__(self, scheduler: ClockCycleScheduler, tenant_id: str): + self.scheduler = scheduler + self.tenant_id = tenant_id + + def forward_backward(self, inputs: Any, labels: Any = None) -> Future: + """Submit forward-backward (returns Future).""" + return self.scheduler.submit_forward_backward(self.tenant_id, inputs, labels) + + def optim_step(self) -> Future: + """Submit optimizer step (returns Future).""" + return self.scheduler.submit_optim_step(self.tenant_id) + + def zero_grad(self) -> Future: + """Submit zero_grad (returns Future).""" + return self.scheduler.submit_zero_grad(self.tenant_id) + + def train_step(self, inputs: Any, labels: Any = None) -> Dict[str, Any]: + """ + Execute a complete training step (blocking). + + Submits forward_backward and optim_step in the same cycle. + """ + fwd_future = self.forward_backward(inputs, labels) + opt_future = self.optim_step() + + result = fwd_future.result() + opt_future.result() + + return result diff --git a/src/twinkle/megatron/distributed/tenant_manager.py b/src/twinkle/megatron/distributed/tenant_manager.py index f378e0e2..260709a4 100644 --- a/src/twinkle/megatron/distributed/tenant_manager.py +++ b/src/twinkle/megatron/distributed/tenant_manager.py @@ -2,7 +2,12 @@ """ Tenant Manager for multi-tenant LoRA training. 
-This module provides tenant lifecycle management +This module provides tenant lifecycle management including: +- Tenant registration/deregistration +- LoRA adapter management (via PEFT) +- Optimizer/scheduler creation +- Tenant context switching + """ import logging @@ -263,6 +268,3 @@ def has(self, tenant_id: str) -> bool: def count(self) -> int: """Number of tenants (does not expose tenant IDs for privacy).""" return len(self._tenants) - - # Note: list() method intentionally not exposed to clients to prevent - # information leakage. Only server-side code should enumerate tenants. diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 08f3af06..f9171e8a 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -591,6 +591,18 @@ def forward_backward(self, if len(microbatch_list) > 1: num_microbatches = len(microbatch_list) + # Helper to convert list/numpy to tensor + def ensure_tensor(value): + if value is None: + return None + if isinstance(value, torch.Tensor): + return value + if isinstance(value, list): + return torch.tensor(value) + if hasattr(value, '__array__'): # numpy array + return torch.from_numpy(value) + return value + # Process each microbatch processed_batches = [] for batch in microbatch_list: @@ -604,6 +616,12 @@ def forward_backward(self, if processor is not None: batch = processor(batch) + # Ensure all tensor fields are proper tensors + if isinstance(batch, dict): + for key in ['input_ids', 'attention_mask', 'labels', 'position_ids']: + if key in batch: + batch[key] = ensure_tensor(batch[key]) + processed_batches.append(batch) # Get first batch for shape info (all batches should have same shape) @@ -614,10 +632,13 @@ def forward_backward(self, cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0 # Get sequence length and batch size from first batch - original_seq_length = first_batch['input_ids'].shape[ - 1] if 'input_ids' in first_batch else 1 - micro_batch_size = 
first_batch['input_ids'].shape[ - 0] if 'input_ids' in first_batch else 1 + input_ids = first_batch.get('input_ids') + if input_ids is not None and isinstance(input_ids, torch.Tensor): + original_seq_length = input_ids.shape[1] if input_ids.dim() > 1 else input_ids.shape[0] + micro_batch_size = input_ids.shape[0] if input_ids.dim() > 1 else 1 + else: + original_seq_length = 1 + micro_batch_size = 1 # For CP > 1, pad seq_length to be divisible by 2*cp_size if cp_size > 1: diff --git a/tests/megatron/test_multi_tenant_benchmark.py b/tests/megatron/test_multi_tenant_benchmark.py new file mode 100644 index 00000000..b9dcb3f8 --- /dev/null +++ b/tests/megatron/test_multi_tenant_benchmark.py @@ -0,0 +1,746 @@ +#!/usr/bin/env python +""" +Benchmark comparison of multi-tenant architectures. + +Compares: +1. Twinkle Mode: Per-tenant serial execution (independent calls) +2. Clock Cycle Mode: Unified scheduling + batched communication + +Key insight: For LLM+LoRA, batch merging is NOT possible because LoRA +weights are embedded in every layer. The benefit of Clock Cycle is +batched communication, not merged computation. 
+ +Metrics: +- Throughput (samples/second) +- Latency (per step) +- Communication efficiency (N syncs vs 1 sync) +""" + +import argparse +import logging +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +# ============ Mock Model with Tinker-compatible API ============ + +class MockBaseModel(nn.Module): + """Mock base model (shared across tenants).""" + + def __init__(self, hidden_size: int, num_layers: int, simulate_ms: float): + super().__init__() + self.hidden_size = hidden_size + self.num_layers = num_layers + self.simulate_ms = simulate_ms + + # Create base layers (frozen) + self.layers = nn.ModuleList([ + nn.Linear(hidden_size, hidden_size, bias=False) + for _ in range(num_layers) + ]) + + for layer in self.layers: + layer.weight.requires_grad = False + + # Stats + self.forward_calls = 0 + self.total_samples = 0 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass - simulates compute time.""" + batch_size = x.size(0) + self.forward_calls += 1 + self.total_samples += batch_size + + # Simulate compute (scales slightly with batch size) + # Key insight: one large batch is more efficient than N small batches + time.sleep(self.simulate_ms / 1000.0 * (1 + 0.1 * (batch_size / 8))) + + for layer in self.layers: + x = layer(x) + x = torch.relu(x) + return x + + def reset_stats(self): + self.forward_calls = 0 + self.total_samples = 0 + + +class MockLoRAAdapter(nn.Module): + """Mock LoRA adapter for a single tenant.""" + + def __init__(self, hidden_size: int, rank: int = 8, simulate_ms: float = 1.0): + super().__init__() + self.hidden_size = hidden_size + self.rank = rank + self.simulate_ms = 
simulate_ms + + self.lora_A = nn.Parameter(torch.randn(rank, hidden_size) * 0.01) + self.lora_B = nn.Parameter(torch.zeros(hidden_size, rank)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply LoRA transformation.""" + time.sleep(self.simulate_ms / 1000.0) + return x + x @ self.lora_A.T @ self.lora_B.T + + +class MockMultiTenantModel(nn.Module): + """ + Mock multi-tenant model with Tinker-compatible API. + + Supports: + - base_forward(): Run base model only (for batch merging) + - apply_lora(): Apply per-tenant LoRA + - scope(): Context manager for tenant selection + - finish_grad_sync_batched(): Batched gradient sync + """ + + def __init__( + self, + hidden_size: int = 256, + num_layers: int = 4, + lora_rank: int = 8, + base_model_ms: float = 10.0, + lora_ms: float = 2.0, + comm_ms: float = 5.0, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_layers = num_layers + self.lora_rank = lora_rank + self.base_model_ms = base_model_ms + self.lora_ms = lora_ms + self.comm_ms = comm_ms + + # Base model (shared) + self.base_model = MockBaseModel(hidden_size, num_layers, base_model_ms) + + # Per-tenant adapters + self._adapters: Dict[str, MockLoRAAdapter] = nn.ModuleDict() + self._optimizers: Dict[str, torch.optim.Optimizer] = {} + + # Current tenant context + self._current_tenant: Optional[str] = None + self._lock = threading.Lock() + + # Stats + self._compute_time = 0.0 + self._comm_time = 0.0 + + def initialize( + self, + tenant_id: Optional[str] = None, + optimizer_kwargs: Optional[Dict] = None, + **kwargs, + ) -> str: + """Initialize a tenant.""" + import uuid + tenant_id = tenant_id or str(uuid.uuid4())[:8] + + with self._lock: + if tenant_id in self._adapters: + raise ValueError(f"Tenant {tenant_id} exists") + + # Create adapter + adapter = MockLoRAAdapter(self.hidden_size, self.lora_rank, self.lora_ms) + self._adapters[tenant_id] = adapter + + # Create optimizer + opt_kwargs = optimizer_kwargs or {'lr': 1e-4} + 
self._optimizers[tenant_id] = torch.optim.AdamW(adapter.parameters(), **opt_kwargs) + + self._current_tenant = tenant_id + + return tenant_id + + def finalize(self, tenant_id: Optional[str] = None): + """Finalize a tenant.""" + tenant_id = tenant_id or self._current_tenant + with self._lock: + if tenant_id in self._adapters: + del self._adapters[tenant_id] + del self._optimizers[tenant_id] + + @contextmanager + def scope(self, tenant_id: str): + """Context manager for tenant scope.""" + old = self._current_tenant + self._current_tenant = tenant_id + try: + yield + finally: + self._current_tenant = old + + def base_forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Run ONLY the base model (for batch merging). + + This is the key for Tinker efficiency - call once for all tenants. + """ + return self.base_model(x) + + def apply_lora(self, features: torch.Tensor, tenant_id: Optional[str] = None) -> torch.Tensor: + """Apply per-tenant LoRA to pre-computed features.""" + tenant_id = tenant_id or self._current_tenant + if tenant_id and tenant_id in self._adapters: + return self._adapters[tenant_id](features) + return features + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Full forward pass (base + current tenant's LoRA).""" + features = self.base_model(x) + return self.apply_lora(features) + + def backward(self, loss: torch.Tensor, tenant_id: Optional[str] = None): + """Backward pass.""" + t0 = time.time() + loss.backward() + self._compute_time += time.time() - t0 + + def finish_grad_sync(self, tenant_id: Optional[str] = None): + """Finish gradient sync for single tenant.""" + time.sleep(self.comm_ms / 1000.0) + self._comm_time += self.comm_ms / 1000.0 + + def finish_grad_sync_batched(self, tenant_ids: List[str]): + """ + Batched gradient sync (Tinker optimization). + + One all-reduce for all tenants instead of N all-reduces. 
+ """ + # Simulate batched communication (more efficient than N separate calls) + # Overhead is sub-linear with number of tenants + batched_time = self.comm_ms / 1000.0 * (1 + 0.1 * len(tenant_ids)) + time.sleep(batched_time) + self._comm_time += batched_time + + def clip_grad_norm(self, tenant_id: Optional[str] = None, max_norm: float = 1.0): + """Clip gradients.""" + tenant_id = tenant_id or self._current_tenant + if tenant_id in self._adapters: + torch.nn.utils.clip_grad_norm_( + self._adapters[tenant_id].parameters(), max_norm + ) + + def step(self, tenant_id: Optional[str] = None): + """Optimizer step.""" + tenant_id = tenant_id or self._current_tenant + if tenant_id in self._optimizers: + self._optimizers[tenant_id].step() + + def zero_grad(self, tenant_id: Optional[str] = None): + """Zero gradients.""" + tenant_id = tenant_id or self._current_tenant + if tenant_id in self._optimizers: + self._optimizers[tenant_id].zero_grad(set_to_none=True) + + def get_stats(self) -> Dict[str, Any]: + return { + 'compute_time': self._compute_time, + 'comm_time': self._comm_time, + 'base_model_forward_calls': self.base_model.forward_calls, + 'base_model_total_samples': self.base_model.total_samples, + } + + def reset_stats(self): + self._compute_time = 0.0 + self._comm_time = 0.0 + self.base_model.reset_stats() + + def tenant_count(self) -> int: + return len(self._adapters) + + def has_tenant(self, tenant_id: str) -> bool: + return tenant_id in self._adapters + + +# ============ Benchmark Classes ============ + +@dataclass +class BenchmarkConfig: + """Configuration for benchmark.""" + num_tenants: int = 4 + steps_per_tenant: int = 10 + batch_size_per_tenant: int = 8 + hidden_size: int = 256 + base_model_ms: float = 10.0 # Base model forward time + lora_ms: float = 2.0 # LoRA forward time per tenant + comm_ms: float = 5.0 # Communication time + clock_cycle_interval_ms: float = 50.0 + + +@dataclass +class BenchmarkResult: + """Result of a benchmark run.""" + mode: str + 
total_time: float + total_steps: int + total_samples: int + throughput_steps: float # steps/second + throughput_samples: float # samples/second + avg_latency: float # seconds per step + base_model_calls: int # Number of base model forward calls + base_model_samples: int # Total samples processed by base model + compute_time: float + comm_time: float + gpu_utilization: float # compute_time / total_time + + def __str__(self): + return ( + f"{self.mode}:\n" + f" Total time: {self.total_time:.2f}s\n" + f" Total steps: {self.total_steps} ({self.total_samples} samples)\n" + f" Throughput: {self.throughput_steps:.2f} steps/s, {self.throughput_samples:.2f} samples/s\n" + f" Avg latency: {self.avg_latency*1000:.2f} ms/step\n" + f" Base model calls: {self.base_model_calls} (samples: {self.base_model_samples})\n" + f" GPU utilization: {self.gpu_utilization*100:.1f}%\n" + ) + + +class TwinkleBenchmark: + """ + Benchmark for Twinkle mode (per-tenant serial execution). + + In this mode: + - Each tenant's request is processed separately + - Base model is called N times (once per tenant) + - Gradient sync is done N times + """ + + def __init__(self, config: BenchmarkConfig): + self.config = config + self.model = MockMultiTenantModel( + hidden_size=config.hidden_size, + base_model_ms=config.base_model_ms, + lora_ms=config.lora_ms, + comm_ms=config.comm_ms, + ) + + def run(self) -> BenchmarkResult: + """Run the benchmark.""" + logger.info("Running Twinkle mode benchmark...") + + # Initialize tenants + tenant_ids = [] + for i in range(self.config.num_tenants): + tid = self.model.initialize(tenant_id=f"tenant_{i}") + tenant_ids.append(tid) + + self.model.reset_stats() + + # Create dummy input + x = torch.randn(self.config.batch_size_per_tenant, self.config.hidden_size) + + total_steps = 0 + total_samples = 0 + step_latencies = [] + + start_time = time.time() + + # Training loop - serial per tenant + for step in range(self.config.steps_per_tenant): + for tenant_id in tenant_ids: + 
step_start = time.time() + + with self.model.scope(tenant_id): + self.model.zero_grad(tenant_id) + output = self.model(x) # Full forward (base + LoRA) + loss = output.mean() + self.model.backward(loss, tenant_id) + self.model.finish_grad_sync(tenant_id) # Individual sync + self.model.clip_grad_norm(tenant_id) + self.model.step(tenant_id) + + step_latencies.append(time.time() - step_start) + total_steps += 1 + total_samples += self.config.batch_size_per_tenant + + total_time = time.time() - start_time + + # Cleanup + for tid in tenant_ids: + self.model.finalize(tid) + + # Calculate metrics + stats = self.model.get_stats() + + return BenchmarkResult( + mode="Twinkle (Serial)", + total_time=total_time, + total_steps=total_steps, + total_samples=total_samples, + throughput_steps=total_steps / total_time, + throughput_samples=total_samples / total_time, + avg_latency=sum(step_latencies) / len(step_latencies), + base_model_calls=stats['base_model_forward_calls'], + base_model_samples=stats['base_model_total_samples'], + compute_time=stats['compute_time'], + comm_time=stats['comm_time'], + gpu_utilization=stats['compute_time'] / total_time, + ) + + +class TinkerBenchmark: + """ + Benchmark for Tinker mode (clock cycle with batch merging). 
+ + In this mode: + - Multiple tenants' requests are batched in each cycle + - Base model is called ONCE per cycle (with merged batch) + - Gradient sync is done ONCE per cycle (batched) + """ + + def __init__(self, config: BenchmarkConfig): + self.config = config + self.model = MockMultiTenantModel( + hidden_size=config.hidden_size, + base_model_ms=config.base_model_ms, + lora_ms=config.lora_ms, + comm_ms=config.comm_ms, + ) + + def run(self) -> BenchmarkResult: + """Run the benchmark.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, + ClockCycleTrainingClient, + ) + + logger.info("Running Tinker mode benchmark...") + + # Initialize tenants + tenant_ids = [] + for i in range(self.config.num_tenants): + tid = self.model.initialize(tenant_id=f"tenant_{i}") + tenant_ids.append(tid) + + self.model.reset_stats() + + # Create scheduler + scheduler = ClockCycleScheduler( + model=self.model, + cycle_interval_ms=self.config.clock_cycle_interval_ms, + ) + scheduler.start() + + # Create clients for each tenant + clients = { + tid: ClockCycleTrainingClient(scheduler, tid) + for tid in tenant_ids + } + + total_steps = 0 + total_samples = 0 + step_latencies = [] + + start_time = time.time() + + # Training loop - all tenants submit concurrently + def tenant_worker(tenant_id: str, client: ClockCycleTrainingClient): + nonlocal total_steps, total_samples + latencies = [] + + # Each tenant has its own batch + x = torch.randn(self.config.batch_size_per_tenant, self.config.hidden_size) + + for step in range(self.config.steps_per_tenant): + step_start = time.time() + + # Submit forward-backward and optimizer step + result = client.train_step(x) + + latencies.append(time.time() - step_start) + total_steps += 1 + total_samples += self.config.batch_size_per_tenant + + return latencies + + # Run all tenants concurrently + with ThreadPoolExecutor(max_workers=self.config.num_tenants) as executor: + futures = { + executor.submit(tenant_worker, tid, 
clients[tid]): tid + for tid in tenant_ids + } + + for future in as_completed(futures): + try: + latencies = future.result() + step_latencies.extend(latencies) + except Exception as e: + logger.error(f"Tenant worker failed: {e}") + + total_time = time.time() - start_time + + # Stop scheduler + scheduler.stop() + + # Get scheduler stats + sched_stats = scheduler.get_summary_stats() + + # Cleanup + for tid in tenant_ids: + self.model.finalize(tid) + + # Calculate metrics + model_stats = self.model.get_stats() + + return BenchmarkResult( + mode="Tinker (Clock Cycle)", + total_time=total_time, + total_steps=total_steps, + total_samples=total_samples, + throughput_steps=total_steps / total_time, + throughput_samples=total_samples / total_time, + avg_latency=sum(step_latencies) / len(step_latencies) if step_latencies else 0, + base_model_calls=model_stats['base_model_forward_calls'], + base_model_samples=model_stats['base_model_total_samples'], + compute_time=model_stats['compute_time'], + comm_time=model_stats['comm_time'], + gpu_utilization=sched_stats.get('gpu_utilization', 0), + ) + + +# ============ Test Functions ============ + +def test_twinkle_mode(): + """Test Twinkle mode functionality.""" + logger.info("Testing Twinkle mode...") + + model = MockMultiTenantModel(base_model_ms=1.0, lora_ms=0.5, comm_ms=0.5) + + # Initialize 2 tenants + tid1 = model.initialize(tenant_id="test_1") + tid2 = model.initialize(tenant_id="test_2") + + assert model.tenant_count() == 2 + assert model.has_tenant(tid1) + assert model.has_tenant(tid2) + + # Training step for tenant 1 + x = torch.randn(4, 256) + with model.scope(tid1): + model.zero_grad(tid1) + output = model(x) + loss = output.mean() + model.backward(loss, tid1) + model.finish_grad_sync(tid1) + model.step(tid1) + + # Training step for tenant 2 + with model.scope(tid2): + model.zero_grad(tid2) + output = model(x) + loss = output.mean() + model.backward(loss, tid2) + model.finish_grad_sync(tid2) + model.step(tid2) + + # 
Verify base model was called twice + stats = model.get_stats() + assert stats['base_model_forward_calls'] == 2, f"Expected 2 calls, got {stats['base_model_forward_calls']}" + + # Cleanup + model.finalize(tid1) + model.finalize(tid2) + + assert model.tenant_count() == 0 + + logger.info("Twinkle mode test PASSED") + return True + + +def test_tinker_mode(): + """Test Tinker mode functionality.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, + ClockCycleTrainingClient, + ) + + logger.info("Testing Tinker mode...") + + model = MockMultiTenantModel(base_model_ms=1.0, lora_ms=0.5, comm_ms=0.5) + + # Initialize 2 tenants + tid1 = model.initialize(tenant_id="test_1") + tid2 = model.initialize(tenant_id="test_2") + + # Create scheduler + scheduler = ClockCycleScheduler(model, cycle_interval_ms=50.0) + scheduler.start() + + # Create clients + client1 = ClockCycleTrainingClient(scheduler, tid1) + client2 = ClockCycleTrainingClient(scheduler, tid2) + + x1 = torch.randn(4, 256) + x2 = torch.randn(4, 256) + + # Both tenants submit requests (should be in same cycle) + future1 = client1.forward_backward(x1) + future2 = client2.forward_backward(x2) + + opt1 = client1.optim_step() + opt2 = client2.optim_step() + + # Wait for results + result1 = future1.result(timeout=5.0) + result2 = future2.result(timeout=5.0) + opt1.result(timeout=5.0) + opt2.result(timeout=5.0) + + assert 'loss' in result1 or 'error' in result1, f"Unexpected result: {result1}" + assert 'loss' in result2 or 'error' in result2, f"Unexpected result: {result2}" + + # Check they were in same cycle + if 'cycle_id' in result1 and 'cycle_id' in result2: + logger.info(f"Cycle IDs: {result1['cycle_id']}, {result2['cycle_id']}") + + # Stop scheduler + scheduler.stop() + + # Check stats + stats = scheduler.get_summary_stats() + logger.info(f"Scheduler stats: {stats}") + + # Cleanup + model.finalize(tid1) + model.finalize(tid2) + + logger.info("Tinker mode test PASSED") + return 
True + + +def test_batch_merging(): + """Test that batch merging works correctly.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import BatchBuilder, TrainingRequest, RequestType + + logger.info("Testing batch merging...") + + builder = BatchBuilder() + + # Create requests from 3 tenants + requests = { + 'tenant_a': TrainingRequest( + tenant_id='tenant_a', + request_type=RequestType.FORWARD_BACKWARD, + inputs=torch.randn(4, 256), + ), + 'tenant_b': TrainingRequest( + tenant_id='tenant_b', + request_type=RequestType.FORWARD_BACKWARD, + inputs=torch.randn(8, 256), + ), + 'tenant_c': TrainingRequest( + tenant_id='tenant_c', + request_type=RequestType.FORWARD_BACKWARD, + inputs=torch.randn(2, 256), + ), + } + + # Build merged batch + merged = builder.build(requests) + + # Verify + assert merged.total_size == 14, f"Expected 14, got {merged.total_size}" + assert merged.merged_inputs.shape == (14, 256), f"Wrong shape: {merged.merged_inputs.shape}" + assert merged.tenant_slices['tenant_a'] == (0, 4) + assert merged.tenant_slices['tenant_b'] == (4, 12) + assert merged.tenant_slices['tenant_c'] == (12, 14) + + logger.info("Batch merging test PASSED") + return True + + +def run_benchmark_comparison(config: BenchmarkConfig): + """Run and compare both benchmarks.""" + print("") + print("=" * 60) + print("Multi-Tenant Architecture Benchmark") + print("=" * 60) + print(f"Config: {config}") + print("") + + # Run Twinkle benchmark + twinkle = TwinkleBenchmark(config) + twinkle_result = twinkle.run() + + # Run Tinker benchmark + tinker = TinkerBenchmark(config) + tinker_result = tinker.run() + + # Print results + print("") + print("=" * 60) + print("Results") + print("=" * 60) + print(twinkle_result) + print(tinker_result) + + # Comparison + print("=" * 60) + print("Comparison") + print("=" * 60) + + # Throughput + speedup = tinker_result.throughput_samples / twinkle_result.throughput_samples + print(f"Throughput speedup (Tinker/Twinkle): {speedup:.2f}x") + + # Base 
model efficiency + base_model_ratio = twinkle_result.base_model_calls / max(tinker_result.base_model_calls, 1) + print(f"Base model calls: {twinkle_result.base_model_calls} vs {tinker_result.base_model_calls} ({base_model_ratio:.1f}x fewer)") + + # Latency + latency_diff = (twinkle_result.avg_latency - tinker_result.avg_latency) / twinkle_result.avg_latency * 100 + print(f"Latency improvement: {latency_diff:.1f}%") + + # GPU utilization + gpu_diff = (tinker_result.gpu_utilization - twinkle_result.gpu_utilization) * 100 + print(f"GPU utilization difference: {gpu_diff:+.1f}%") + + return twinkle_result, tinker_result + + +def main(): + parser = argparse.ArgumentParser(description="Multi-tenant benchmark") + parser.add_argument("--num-tenants", type=int, default=4) + parser.add_argument("--steps", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--base-model-ms", type=float, default=10.0) + parser.add_argument("--lora-ms", type=float, default=2.0) + parser.add_argument("--comm-ms", type=float, default=5.0) + parser.add_argument("--cycle-ms", type=float, default=50.0) + parser.add_argument("--test-only", action="store_true", help="Run tests only") + args = parser.parse_args() + + if args.test_only: + test_twinkle_mode() + test_batch_merging() + test_tinker_mode() + logger.info("All tests passed!") + return + + config = BenchmarkConfig( + num_tenants=args.num_tenants, + steps_per_tenant=args.steps, + batch_size_per_tenant=args.batch_size, + base_model_ms=args.base_model_ms, + lora_ms=args.lora_ms, + comm_ms=args.comm_ms, + clock_cycle_interval_ms=args.cycle_ms, + ) + + run_benchmark_comparison(config) + + +if __name__ == "__main__": + main() diff --git a/tests/megatron/test_multi_tenant_ddp.py b/tests/megatron/test_multi_tenant_ddp.py deleted file mode 100644 index 3056b3ee..00000000 --- a/tests/megatron/test_multi_tenant_ddp.py +++ /dev/null @@ -1,181 +0,0 @@ -""" -Unit tests for Multi-Tenant LoRA DDP. 
- -Tests: -1. Tenant context (ContextVar) -2. Tenant manager lifecycle -3. Dynamic tenant add/remove -""" - -import threading -import unittest -from unittest.mock import MagicMock, patch - -import torch -import torch.nn as nn - - -class TestTenantContext(unittest.TestCase): - """Tests for tenant_context module.""" - - def setUp(self): - from twinkle.megatron.distributed.tenant_context import set_current_tenant - set_current_tenant(None) - - def test_get_set(self): - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, set_current_tenant - ) - - self.assertIsNone(get_current_tenant()) - set_current_tenant("a") - self.assertEqual(get_current_tenant(), "a") - set_current_tenant(None) - self.assertIsNone(get_current_tenant()) - - def test_scope(self): - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, tenant_scope - ) - - with tenant_scope("x"): - self.assertEqual(get_current_tenant(), "x") - with tenant_scope("y"): - self.assertEqual(get_current_tenant(), "y") - self.assertEqual(get_current_tenant(), "x") - self.assertIsNone(get_current_tenant()) - - def test_require_tenant(self): - from twinkle.megatron.distributed.tenant_context import ( - require_tenant, tenant_scope - ) - - with self.assertRaises(RuntimeError): - require_tenant() - - with tenant_scope("t"): - self.assertEqual(require_tenant(), "t") - - def test_thread_isolation(self): - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, set_current_tenant - ) - - results = {} - - def worker(tid): - set_current_tenant(tid) - import time - time.sleep(0.01) - results[tid] = get_current_tenant() - - threads = [threading.Thread(target=worker, args=(f"t{i}",)) for i in range(5)] - for t in threads: - t.start() - for t in threads: - t.join() - - for i in range(5): - self.assertEqual(results[f"t{i}"], f"t{i}") - - def test_generate_id(self): - from twinkle.megatron.distributed.tenant_context import generate_tenant_id - - ids = 
[generate_tenant_id() for _ in range(100)] - self.assertEqual(len(ids), len(set(ids))) - - -class TestTenantManager(unittest.TestCase): - """Tests for TenantManager.""" - - def test_initialize_finalize(self): - from twinkle.megatron.distributed.tenant_manager import TenantManager - - model = nn.Linear(10, 10) - manager = TenantManager(model) - - # Mock PEFT - with patch('twinkle.megatron.distributed.tenant_manager.PEFT_AVAILABLE', False): - # Add fake lora param - lora_param = nn.Parameter(torch.randn(4, 10)) - lora_param.requires_grad = True - model.lora_A = nn.ParameterDict({'test': lora_param}) - - # Need to patch named_parameters - original_named_params = model.named_parameters - def mock_named_params(): - yield 'weight', model.weight - yield 'lora_A.test.lora_A', lora_param - model.named_parameters = mock_named_params - - tid = manager.initialize( - optimizer_kwargs={'lr': 1e-4}, - adapter_name='test', - ) - - self.assertTrue(manager.has(tid)) - self.assertIn(tid, manager.list()) - - state = manager.get(tid) - self.assertEqual(state.adapter_name, 'test') - - manager.finalize(tid) - self.assertFalse(manager.has(tid)) - - def test_callbacks(self): - from twinkle.megatron.distributed.tenant_manager import TenantManager - - model = nn.Linear(10, 10) - manager = TenantManager(model) - - added = [] - removed = [] - - manager.register_add_callback(lambda s: added.append(s.tenant_id)) - manager.register_remove_callback(lambda s: removed.append(s.tenant_id)) - - with patch('twinkle.megatron.distributed.tenant_manager.PEFT_AVAILABLE', False): - lora_param = nn.Parameter(torch.randn(4, 10)) - original_named_params = model.named_parameters - def mock_named_params(): - yield 'lora_A.test.lora_A', lora_param - model.named_parameters = mock_named_params - - tid = manager.initialize(adapter_name='test') - self.assertEqual(added, [tid]) - - manager.finalize(tid) - self.assertEqual(removed, [tid]) - - -class TestMultiTenantDDP(unittest.TestCase): - """Tests for 
MultiTenantLoRADDP.""" - - @patch('twinkle.megatron.distributed.multi_tenant_ddp.MEGATRON_AVAILABLE', False) - def test_requires_megatron(self): - from twinkle.megatron.distributed.multi_tenant_ddp import MultiTenantLoRADDP - - with self.assertRaises(ImportError): - MultiTenantLoRADDP( - config=MagicMock(), - ddp_config=MagicMock(), - module=nn.Linear(10, 10), - ) - - -class TestMegatronMultiAdapter(unittest.TestCase): - """Tests for MegatronMultiAdapter.""" - - def test_adapter_var(self): - from twinkle.megatron.model.multi_tenant_megatron import MegatronMultiAdapter - - MegatronMultiAdapter._patched = False - - self.assertIsNone(MegatronMultiAdapter.get_current_adapter_name()) - MegatronMultiAdapter.set_current_adapter_name("a") - self.assertEqual(MegatronMultiAdapter.get_current_adapter_name(), "a") - MegatronMultiAdapter.set_current_adapter_name(None) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/megatron/test_multi_tenant_modules.py b/tests/megatron/test_multi_tenant_modules.py new file mode 100644 index 00000000..b2c3b09e --- /dev/null +++ b/tests/megatron/test_multi_tenant_modules.py @@ -0,0 +1,641 @@ +#!/usr/bin/env python +""" +Unit tests for multi-tenant modules. + +Tests: +1. tenant_context.py - ContextVar-based tenant management +2. tenant_manager.py - Tenant lifecycle management (WIP, requires PEFT) +3. multi_tenant_ddp.py - Multi-tenant DDP (requires Megatron) +4. clock_cycle_scheduler.py - Clock cycle scheduling + +Each module is tested independently with mock dependencies where needed. 
+""" + +import logging +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Test 1: tenant_context.py +# ============================================================================ + +class TestTenantContext(unittest.TestCase): + """Tests for tenant_context module.""" + + def setUp(self): + """Reset tenant context before each test.""" + from twinkle.megatron.distributed.tenant_context import set_current_tenant + set_current_tenant(None) + + def test_get_set_tenant(self): + """Test basic get/set tenant.""" + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, set_current_tenant + ) + + # Initially None + self.assertIsNone(get_current_tenant()) + + # Set tenant + set_current_tenant('tenant_a') + self.assertEqual(get_current_tenant(), 'tenant_a') + + # Change tenant + set_current_tenant('tenant_b') + self.assertEqual(get_current_tenant(), 'tenant_b') + + # Clear tenant + set_current_tenant(None) + self.assertIsNone(get_current_tenant()) + + def test_require_tenant(self): + """Test require_tenant raises error when not set.""" + from twinkle.megatron.distributed.tenant_context import ( + require_tenant, set_current_tenant + ) + + # Should raise when not set + with self.assertRaises(RuntimeError): + require_tenant() + + # Should return when set + set_current_tenant('tenant_a') + self.assertEqual(require_tenant(), 'tenant_a') + + def test_tenant_scope(self): + """Test tenant_scope context manager.""" + from twinkle.megatron.distributed.tenant_context import ( + 
get_current_tenant, tenant_scope, set_current_tenant + ) + + set_current_tenant('outer') + + with tenant_scope('inner'): + self.assertEqual(get_current_tenant(), 'inner') + + # Should restore after context + self.assertEqual(get_current_tenant(), 'outer') + + def test_nested_scopes(self): + """Test nested tenant scopes.""" + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, tenant_scope + ) + + with tenant_scope('a'): + self.assertEqual(get_current_tenant(), 'a') + + with tenant_scope('b'): + self.assertEqual(get_current_tenant(), 'b') + + with tenant_scope('c'): + self.assertEqual(get_current_tenant(), 'c') + + self.assertEqual(get_current_tenant(), 'b') + + self.assertEqual(get_current_tenant(), 'a') + + def test_thread_isolation(self): + """Test that tenant context is isolated between threads.""" + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, set_current_tenant + ) + + results = {} + + def thread_func(tenant_id: str, delay: float): + set_current_tenant(tenant_id) + time.sleep(delay) + results[tenant_id] = get_current_tenant() + + # Run multiple threads + threads = [ + threading.Thread(target=thread_func, args=('thread_a', 0.1)), + threading.Thread(target=thread_func, args=('thread_b', 0.05)), + threading.Thread(target=thread_func, args=('thread_c', 0.15)), + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should have its own context + self.assertEqual(results['thread_a'], 'thread_a') + self.assertEqual(results['thread_b'], 'thread_b') + self.assertEqual(results['thread_c'], 'thread_c') + + def test_generate_tenant_id(self): + """Test tenant ID generation.""" + from twinkle.megatron.distributed.tenant_context import generate_tenant_id + + id1 = generate_tenant_id() + id2 = generate_tenant_id() + + # Should be unique + self.assertNotEqual(id1, id2) + + # Should be 8 chars + self.assertEqual(len(id1), 8) + self.assertEqual(len(id2), 8) + + def 
test_with_tenant_context_decorator(self): + """Test @with_tenant_context decorator.""" + from twinkle.megatron.distributed.tenant_context import ( + with_tenant_context, tenant_scope + ) + + @with_tenant_context + def example_func(tenant_id: Optional[str] = None): + return tenant_id + + # Should use context when tenant_id not provided + with tenant_scope('context_tenant'): + result = example_func() + self.assertEqual(result, 'context_tenant') + + # Should use explicit tenant_id when provided + with tenant_scope('context_tenant'): + result = example_func(tenant_id='explicit_tenant') + self.assertEqual(result, 'explicit_tenant') + + +# ============================================================================ +# Test 2: clock_cycle_scheduler.py +# ============================================================================ + +class MockMultiTenantModel(nn.Module): + """Mock model that implements the required interface for ClockCycleScheduler.""" + + def __init__(self, hidden_size: int = 64, simulate_ms: float = 1.0): + super().__init__() + self.hidden_size = hidden_size + self.simulate_ms = simulate_ms + + # Base layer (frozen) + self.base = nn.Linear(hidden_size, hidden_size) + self.base.weight.requires_grad = False + + # Per-tenant adapters + self._adapters: Dict[str, nn.Module] = {} + self._optimizers: Dict[str, torch.optim.Optimizer] = {} + self._current_tenant: Optional[str] = None + + def add_tenant(self, tenant_id: str) -> None: + """Add a tenant with LoRA adapter.""" + adapter = nn.Linear(self.hidden_size, self.hidden_size) + self._adapters[tenant_id] = adapter + self._optimizers[tenant_id] = torch.optim.SGD(adapter.parameters(), lr=0.01) + + def remove_tenant(self, tenant_id: str) -> None: + """Remove a tenant.""" + if tenant_id in self._adapters: + del self._adapters[tenant_id] + del self._optimizers[tenant_id] + + @contextmanager + def scope(self, tenant_id: str): + """Context manager for tenant scope.""" + old = self._current_tenant + 
self._current_tenant = tenant_id + try: + yield + finally: + self._current_tenant = old + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass using current tenant's adapter.""" + time.sleep(self.simulate_ms / 1000.0) + + out = self.base(x) + if self._current_tenant and self._current_tenant in self._adapters: + out = out + self._adapters[self._current_tenant](x) + return out + + def zero_grad(self, tenant_id: str) -> None: + """Zero gradients for tenant.""" + if tenant_id in self._optimizers: + self._optimizers[tenant_id].zero_grad(set_to_none=True) + + def step(self, tenant_id: str) -> None: + """Optimizer step for tenant.""" + if tenant_id in self._optimizers: + self._optimizers[tenant_id].step() + + def clip_grad_norm(self, tenant_id: str, max_norm: float = 1.0) -> None: + """Clip gradients for tenant.""" + if tenant_id in self._adapters: + torch.nn.utils.clip_grad_norm_( + self._adapters[tenant_id].parameters(), max_norm + ) + + def finish_grad_sync(self, tenant_id: str) -> None: + """Gradient sync for single tenant (no-op in non-distributed).""" + pass + + def finish_grad_sync_batched(self, tenant_ids: List[str]) -> None: + """Batched gradient sync (no-op in non-distributed).""" + pass + + +class TestClockCycleScheduler(unittest.TestCase): + """Tests for clock_cycle_scheduler module.""" + + def test_model_interface_validation(self): + """Test that scheduler validates model interface.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ModelInterfaceError + ) + + # Model without required methods should fail + bad_model = nn.Linear(10, 10) + with self.assertRaises(ModelInterfaceError): + ClockCycleScheduler(bad_model) + + # Good model should work + good_model = MockMultiTenantModel() + scheduler = ClockCycleScheduler(good_model, cycle_interval_ms=10) + self.assertIsNotNone(scheduler) + + def test_basic_training_step(self): + """Test basic training step through scheduler.""" + from 
twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ClockCycleTrainingClient + ) + + model = MockMultiTenantModel(simulate_ms=0.5) + model.add_tenant('tenant_a') + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=20) + scheduler.start() + + try: + client = ClockCycleTrainingClient(scheduler, 'tenant_a') + + x = torch.randn(4, 64) + result = client.train_step(x) + + self.assertIn('loss', result) + self.assertIn('cycle_id', result) + self.assertEqual(result['batch_size'], 4) + + finally: + scheduler.stop() + model.remove_tenant('tenant_a') + + def test_multi_tenant_concurrent(self): + """Test multiple tenants submitting concurrently.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ClockCycleTrainingClient + ) + + model = MockMultiTenantModel(simulate_ms=0.5) + model.add_tenant('tenant_a') + model.add_tenant('tenant_b') + model.add_tenant('tenant_c') + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=50) + scheduler.start() + + try: + clients = { + tid: ClockCycleTrainingClient(scheduler, tid) + for tid in ['tenant_a', 'tenant_b', 'tenant_c'] + } + + # Submit from multiple threads + results = {} + + def worker(tenant_id: str, client: ClockCycleTrainingClient): + x = torch.randn(4, 64) + return client.train_step(x) + + with ThreadPoolExecutor(max_workers=3) as executor: + futures = { + executor.submit(worker, tid, clients[tid]): tid + for tid in clients + } + for future in as_completed(futures): + tid = futures[future] + results[tid] = future.result() + + # All should succeed + for tid, result in results.items(): + self.assertIn('loss', result) + self.assertIn('cycle_id', result) + + # Check stats + stats = scheduler.get_summary_stats() + self.assertGreater(stats['total_cycles'], 0) + self.assertEqual(stats['total_samples'], 12) # 3 tenants * 4 samples + + finally: + scheduler.stop() + for tid in ['tenant_a', 'tenant_b', 'tenant_c']: + model.remove_tenant(tid) + + def 
test_gradient_isolation(self): + """Test that gradients are isolated between tenants.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ClockCycleTrainingClient + ) + + model = MockMultiTenantModel(simulate_ms=0.5) + model.add_tenant('tenant_a') + model.add_tenant('tenant_b') + + # Get initial weights + weight_a_before = model._adapters['tenant_a'].weight.clone() + weight_b_before = model._adapters['tenant_b'].weight.clone() + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=20) + scheduler.start() + + try: + # Only tenant_a trains + client_a = ClockCycleTrainingClient(scheduler, 'tenant_a') + x = torch.randn(4, 64) + client_a.train_step(x) + + # tenant_a weights should change + weight_a_after = model._adapters['tenant_a'].weight + self.assertFalse(torch.allclose(weight_a_before, weight_a_after)) + + # tenant_b weights should NOT change + weight_b_after = model._adapters['tenant_b'].weight + self.assertTrue(torch.allclose(weight_b_before, weight_b_after)) + + finally: + scheduler.stop() + model.remove_tenant('tenant_a') + model.remove_tenant('tenant_b') + + def test_error_handling(self): + """Test error handling for failed requests.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler + ) + + # Create a model that raises error on forward for unknown tenant + class FailingModel(MockMultiTenantModel): + def forward(self, x): + if self._current_tenant not in self._adapters: + raise KeyError(f"Tenant '{self._current_tenant}' not found") + return super().forward(x) + + model = FailingModel() + # Don't add any tenants - requests should fail + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=20) + scheduler.start() + + try: + # Submit request for non-existent tenant + future = scheduler.submit_forward_backward('nonexistent', torch.randn(4, 64)) + + # Should raise exception + with self.assertRaises(Exception): + future.result(timeout=5.0) + + finally: + scheduler.stop() + 
+ def test_cycle_stats(self): + """Test cycle statistics collection.""" + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ClockCycleTrainingClient + ) + + model = MockMultiTenantModel(simulate_ms=1.0) + model.add_tenant('tenant_a') + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=50) + scheduler.start() + + try: + client = ClockCycleTrainingClient(scheduler, 'tenant_a') + + # Run multiple steps + for _ in range(3): + x = torch.randn(4, 64) + client.train_step(x) + + # Get stats + stats_list = scheduler.get_stats() + summary = scheduler.get_summary_stats() + + self.assertEqual(len(stats_list), 3) + self.assertEqual(summary['total_cycles'], 3) + self.assertEqual(summary['total_samples'], 12) + + # Check individual stats + for stat in stats_list: + self.assertGreater(stat.forward_time, 0) + self.assertGreater(stat.duration, 0) + + finally: + scheduler.stop() + model.remove_tenant('tenant_a') + + +# ============================================================================ +# Test 3: multi_tenant_ddp.py (Mock test - requires Megatron) +# ============================================================================ + +class TestMultiTenantDDP(unittest.TestCase): + """Tests for multi_tenant_ddp module (mocked).""" + + def test_tenant_ddp_state_dataclass(self): + """Test TenantDDPState dataclass.""" + from twinkle.megatron.distributed.multi_tenant_ddp import TenantDDPState + + state = TenantDDPState(tenant_id='test_tenant') + + self.assertEqual(state.tenant_id, 'test_tenant') + self.assertEqual(state.params, []) + self.assertEqual(state.buffers, []) + self.assertEqual(state.bucket_groups, []) + self.assertIsNone(state.process_group) + + @unittest.skipUnless( + False, # Skip by default - requires Megatron + "Requires Megatron-Core" + ) + def test_multi_tenant_lora_ddp_creation(self): + """Test MultiTenantLoRADDP creation (requires Megatron).""" + pass + + def test_requires_megatron(self): + """Test that 
MultiTenantLoRADDP requires Megatron.""" + from unittest.mock import MagicMock, patch + + with patch('twinkle.megatron.distributed.multi_tenant_ddp.MEGATRON_AVAILABLE', False): + from twinkle.megatron.distributed.multi_tenant_ddp import MultiTenantLoRADDP + + with self.assertRaises(ImportError): + MultiTenantLoRADDP( + config=MagicMock(), + ddp_config=MagicMock(), + module=nn.Linear(10, 10), + ) + + +# ============================================================================ +# Test 4: MegatronMultiAdapter +# ============================================================================ + +class TestMegatronMultiAdapter(unittest.TestCase): + """Tests for MegatronMultiAdapter.""" + + def test_adapter_context_var(self): + """Test adapter name ContextVar management.""" + from twinkle.megatron.model.multi_tenant_megatron import MegatronMultiAdapter + + # Reset state + MegatronMultiAdapter._patched = False + + # Test get/set + self.assertIsNone(MegatronMultiAdapter.get_current_adapter_name()) + MegatronMultiAdapter.set_current_adapter_name("adapter_a") + self.assertEqual(MegatronMultiAdapter.get_current_adapter_name(), "adapter_a") + MegatronMultiAdapter.set_current_adapter_name(None) + self.assertIsNone(MegatronMultiAdapter.get_current_adapter_name()) + + +# ============================================================================ +# Test 5: Integration test - tenant_context + clock_cycle_scheduler +# ============================================================================ + +class TestIntegration(unittest.TestCase): + """Integration tests combining multiple modules.""" + + def test_context_with_scheduler(self): + """Test that tenant_context works with scheduler.""" + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, tenant_scope, set_current_tenant + ) + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ClockCycleTrainingClient + ) + + model = MockMultiTenantModel(simulate_ms=0.5) + 
model.add_tenant('tenant_a') + model.add_tenant('tenant_b') + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=30) + scheduler.start() + + try: + # Test that context propagates correctly + with tenant_scope('tenant_a'): + self.assertEqual(get_current_tenant(), 'tenant_a') + + client = ClockCycleTrainingClient(scheduler, 'tenant_a') + x = torch.randn(4, 64) + result = client.train_step(x) + + self.assertIn('loss', result) + + # Context should be cleared outside + set_current_tenant(None) + self.assertIsNone(get_current_tenant()) + + finally: + scheduler.stop() + model.remove_tenant('tenant_a') + model.remove_tenant('tenant_b') + + def test_multi_threaded_with_context(self): + """Test multi-threaded training with tenant context.""" + from twinkle.megatron.distributed.tenant_context import ( + get_current_tenant, tenant_scope + ) + from twinkle.megatron.distributed.clock_cycle_scheduler import ( + ClockCycleScheduler, ClockCycleTrainingClient + ) + + model = MockMultiTenantModel(simulate_ms=0.5) + for i in range(4): + model.add_tenant(f'tenant_{i}') + + scheduler = ClockCycleScheduler(model, cycle_interval_ms=50) + scheduler.start() + + results = {} + errors = [] + + def worker(tenant_id: str): + try: + with tenant_scope(tenant_id): + # Verify context is correct + if get_current_tenant() != tenant_id: + errors.append(f"Context mismatch for {tenant_id}") + return + + client = ClockCycleTrainingClient(scheduler, tenant_id) + x = torch.randn(4, 64) + result = client.train_step(x) + results[tenant_id] = result + except Exception as e: + errors.append(str(e)) + + try: + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(worker, f'tenant_{i}') + for i in range(4) + ] + for f in futures: + f.result() + + self.assertEqual(len(errors), 0, f"Errors: {errors}") + self.assertEqual(len(results), 4) + + for tid, result in results.items(): + self.assertIn('loss', result) + + finally: + scheduler.stop() + for i in range(4): + 
model.remove_tenant(f'tenant_{i}') + + +# ============================================================================ +# Main +# ============================================================================ + +def run_tests(): + """Run all tests.""" + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add test cases + suite.addTests(loader.loadTestsFromTestCase(TestTenantContext)) + suite.addTests(loader.loadTestsFromTestCase(TestClockCycleScheduler)) + suite.addTests(loader.loadTestsFromTestCase(TestMultiTenantDDP)) + suite.addTests(loader.loadTestsFromTestCase(TestMegatronMultiAdapter)) + suite.addTests(loader.loadTestsFromTestCase(TestIntegration)) + + # Run + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + return result.wasSuccessful() + + +if __name__ == '__main__': + success = run_tests() + exit(0 if success else 1) From 059c0f30ed2b9adc1cfc90a70b786a5be28dc4df Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 14:50:08 +0800 Subject: [PATCH 10/22] fix --- .gitignore | 1 + src/twinkle/utils/parallel.py | 27 ++------------------------- 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index de8aa916..a380c1a8 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ wheels/ /package /temp MANIFEST +.locks/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py index 619fcb29..d3d7a384 100644 --- a/src/twinkle/utils/parallel.py +++ b/src/twinkle/utils/parallel.py @@ -8,17 +8,6 @@ shutil.rmtree('.locks', ignore_errors=True) os.makedirs('.locks', exist_ok=True) -def acquire_lock(lock, blocking): - try: - lock.acquire(blocking=blocking) - return True - except Exception: - return False - - -def release_lock(lock): - lock.release() - def acquire_lock(lock, blocking): try: @@ -33,20 +22,8 @@ def release_lock(lock): @contextmanager -def processing_lock(lock_file: str, 
timeout: float = 600.0): - """Acquire a file lock for distributed-safe processing. - - Args: - lock_file: Name of the lock file (will be sanitized). - timeout: Maximum time to wait for lock acquisition in seconds. - - In distributed training, only rank 0 should process data while - other ranks wait. This lock ensures that. - """ - # Sanitize lock file name - safe_name = lock_file.replace('/', '_').replace(':', '_').replace(' ', '_') - lock_path = os.path.join(_locks_dir, f"{safe_name}.lock") - lock = FileLock(lock_path, timeout=timeout) +def processing_lock(lock_file: str): + lock = FileLock(os.path.join('.locks', f"{lock_file}.lock")) if acquire_lock(lock, False): try: From 60f9774977ebcb84a295979552489c91c784ec85 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 15:35:18 +0800 Subject: [PATCH 11/22] Remove untested multi-tenant code from tracking Multi-tenant LoRA training code has not been fully tested. Files removed from git tracking (kept locally): - src/twinkle/megatron/distributed/ - src/twinkle/megatron/model/multi_tenant_megatron.py - cookbook/megatron/megatron_multi_tenant/ - tests/megatron/test_multi_tenant_*.py --- .../megatron_multi_tenant/__init__.py | 5 - .../megatron/megatron_multi_tenant/client.py | 164 ---- .../megatron/megatron_multi_tenant/server.py | 242 ------ src/twinkle/megatron/distributed/__init__.py | 49 -- .../distributed/clock_cycle_scheduler.py | 598 -------------- .../megatron/distributed/multi_tenant_ddp.py | 398 ---------- .../megatron/distributed/tenant_context.py | 106 --- .../megatron/distributed/tenant_manager.py | 270 ------- .../megatron/model/multi_tenant_megatron.py | 333 -------- tests/megatron/test_multi_tenant_benchmark.py | 746 ------------------ tests/megatron/test_multi_tenant_modules.py | 641 --------------- 11 files changed, 3552 deletions(-) delete mode 100644 cookbook/megatron/megatron_multi_tenant/__init__.py delete mode 100644 cookbook/megatron/megatron_multi_tenant/client.py delete mode 100644 
cookbook/megatron/megatron_multi_tenant/server.py delete mode 100644 src/twinkle/megatron/distributed/__init__.py delete mode 100644 src/twinkle/megatron/distributed/clock_cycle_scheduler.py delete mode 100644 src/twinkle/megatron/distributed/multi_tenant_ddp.py delete mode 100644 src/twinkle/megatron/distributed/tenant_context.py delete mode 100644 src/twinkle/megatron/distributed/tenant_manager.py delete mode 100644 src/twinkle/megatron/model/multi_tenant_megatron.py delete mode 100644 tests/megatron/test_multi_tenant_benchmark.py delete mode 100644 tests/megatron/test_multi_tenant_modules.py diff --git a/cookbook/megatron/megatron_multi_tenant/__init__.py b/cookbook/megatron/megatron_multi_tenant/__init__.py deleted file mode 100644 index bda8411f..00000000 --- a/cookbook/megatron/megatron_multi_tenant/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Multi-Tenant Megatron LoRA Training Demo -# -# This directory contains demo code for multi-tenant training: -# - server.py: FastAPI server managing shared base model -# - client.py: Training client for remote training diff --git a/cookbook/megatron/megatron_multi_tenant/client.py b/cookbook/megatron/megatron_multi_tenant/client.py deleted file mode 100644 index 3772a414..00000000 --- a/cookbook/megatron/megatron_multi_tenant/client.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Multi-Tenant Megatron LoRA Training - Client Example. - -Simple training loop using remote multi-tenant server. 
- -Usage: - python client.py --server-url http://localhost:8080 -""" - -import logging -import time -from dataclasses import dataclass -from typing import Any, Dict, Iterator, Optional - -import requests - -logger = logging.getLogger(__name__) - - -@dataclass -class Config: - """Training configuration.""" - server_url: str = "http://localhost:8080" - lora_rank: int = 8 - learning_rate: float = 1e-4 - batch_size: int = 8 - gradient_accumulation_steps: int = 4 - max_grad_norm: float = 1.0 - log_every: int = 10 - - -class TrainingClient: - """ - Simple client for multi-tenant LoRA training. - - Example: - >>> client = TrainingClient(server_url) - >>> client.initialize(lora_rank=8, learning_rate=1e-4) - >>> - >>> for batch in dataloader: - ... result = client.forward_backward(batch) - ... if client.should_step(): - ... client.step() - >>> - >>> client.finalize() - """ - - def __init__(self, server_url: str = "http://localhost:8080"): - self.server_url = server_url.rstrip('/') - self.tenant_id: Optional[str] = None - self._session = requests.Session() - self._accumulated = 0 - self._ga_steps = 1 - - def _post(self, endpoint: str, **kwargs) -> Dict: - """Make POST request.""" - headers = {"X-Tenant-ID": self.tenant_id} if self.tenant_id else {} - resp = self._session.post( - f"{self.server_url}{endpoint}", - headers=headers, - json=kwargs, - timeout=300, - ) - resp.raise_for_status() - return resp.json() - - def initialize( - self, - lora_rank: int = 8, - learning_rate: float = 1e-4, - gradient_accumulation_steps: int = 1, - **kwargs, - ) -> str: - """Initialize tenant on server.""" - result = self._post( - "/initialize", - lora_config={"r": lora_rank, "target_modules": "all-linear"}, - optimizer_kwargs={"lr": learning_rate}, - gradient_accumulation_steps=gradient_accumulation_steps, - **kwargs, - ) - self.tenant_id = result["tenant_id"] - self._ga_steps = gradient_accumulation_steps - logger.info(f"Initialized: {self.tenant_id}") - return self.tenant_id - - def 
finalize(self): - """Cleanup tenant.""" - if self.tenant_id: - self._post("/finalize") - logger.info(f"Finalized: {self.tenant_id}") - self.tenant_id = None - - def forward_backward(self, inputs: Any) -> Dict: - """Forward + backward pass.""" - result = self._post("/forward_backward", inputs=inputs) - self._accumulated += 1 - return result.get("data", {}) - - def should_step(self) -> bool: - """Check if optimizer step should happen.""" - return self._accumulated >= self._ga_steps - - def step(self): - """Optimizer step.""" - self._post("/finish_grad_sync") - self._post("/clip_grad_norm") - self._post("/step") - self._post("/zero_grad") - self._post("/lr_step") - self._accumulated = 0 - - def __enter__(self): - return self - - def __exit__(self, *args): - self.finalize() - - -def main(config: Config): - """Example training loop.""" - logging.basicConfig(level=logging.INFO) - - # Create client - client = TrainingClient(config.server_url) - - # Initialize - client.initialize( - lora_rank=config.lora_rank, - learning_rate=config.learning_rate, - gradient_accumulation_steps=config.gradient_accumulation_steps, - ) - - try: - # Training loop - for step in range(100): - start = time.time() - - # Create dummy batch (replace with your data loading) - batch = { - "input_ids": list(range(128)), - "attention_mask": [1] * 128, - "labels": list(range(128)), - } - - # Forward + backward - result = client.forward_backward(batch) - - # Optimizer step - if client.should_step(): - client.step() - - if step % config.log_every == 0: - elapsed = time.time() - start - logger.info(f"Step {step}, time: {elapsed:.2f}s") - - logger.info("Training complete!") - - finally: - client.finalize() - - -if __name__ == "__main__": - main(Config()) diff --git a/cookbook/megatron/megatron_multi_tenant/server.py b/cookbook/megatron/megatron_multi_tenant/server.py deleted file mode 100644 index d2ca3c3a..00000000 --- a/cookbook/megatron/megatron_multi_tenant/server.py +++ /dev/null @@ -1,242 +0,0 @@ -""" 
-Multi-Tenant Megatron LoRA Training - Server. - -Creates a shared base model and provides APIs for multi-tenant training. - -Usage: - python server.py --model-id Qwen/Qwen2.5-7B --tp 2 --port 8080 -""" - -import argparse -import logging -import threading -import time -from typing import Any, Dict, List, Optional - -import torch -from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -# ============ Request/Response Models ============ - -class InitializeRequest(BaseModel): - lora_config: Optional[Dict[str, Any]] = None - optimizer_cls: str = "AdamW" - optimizer_kwargs: Optional[Dict[str, Any]] = None - gradient_accumulation_steps: int = 1 - max_grad_norm: float = 1.0 - -class InputsRequest(BaseModel): - inputs: Any - -class TenantResponse(BaseModel): - status: str = "ok" - tenant_id: Optional[str] = None - data: Optional[Any] = None - - -# ============ Server ============ - -class MultiTenantServer: - """Server managing multi-tenant Megatron model.""" - - TIMEOUT = 60 * 30 # 30 min heartbeat timeout - - def __init__(self, model_id: str, tp_size: int = 1): - self.model_id = model_id - self.tp_size = tp_size - self.model = None - self._heartbeats: Dict[str, float] = {} - self._lock = threading.Lock() - - def setup(self): - """Initialize model.""" - from twinkle.megatron.model import ( - MultiTenantMegatronModel, - initialize_megatron_model, - ) - - logger.info(f"Loading model: {self.model_id}") - base_model, config = initialize_megatron_model( - model_id=self.model_id, - tensor_parallel_size=self.tp_size, - ) - - # Freeze base model - for p in base_model.parameters(): - p.requires_grad = False - - self.model = MultiTenantMegatronModel(base_model, config) - logger.info("Server ready") - - # Start heartbeat monitor - threading.Thread(target=self._monitor, daemon=True).start() - - def _monitor(self): - """Cleanup inactive tenants.""" - while True: - time.sleep(60) - now = time.time() - with 
self._lock: - expired = [t for t, ts in self._heartbeats.items() if now - ts > self.TIMEOUT] - for tid in expired: - logger.warning(f"Tenant {tid} timed out") - try: - self.finalize(tid) - except: - pass - - def _heartbeat(self, tenant_id: str): - with self._lock: - self._heartbeats[tenant_id] = time.time() - - def initialize(self, request: InitializeRequest) -> str: - """Initialize tenant.""" - from peft import LoraConfig - - lora_config = None - if request.lora_config: - lora_config = LoraConfig(**request.lora_config) - - opt_map = {"AdamW": torch.optim.AdamW, "Adam": torch.optim.Adam} - opt_cls = opt_map.get(request.optimizer_cls, torch.optim.AdamW) - - tenant_id = self.model.initialize( - lora_config=lora_config, - optimizer_cls=opt_cls, - optimizer_kwargs=request.optimizer_kwargs, - gradient_accumulation_steps=request.gradient_accumulation_steps, - max_grad_norm=request.max_grad_norm, - ) - - self._heartbeat(tenant_id) - return tenant_id - - def finalize(self, tenant_id: str): - """Finalize tenant.""" - self.model.finalize(tenant_id) - with self._lock: - self._heartbeats.pop(tenant_id, None) - - def forward_backward(self, tenant_id: str, inputs: Any) -> Dict: - """Forward + backward.""" - self._heartbeat(tenant_id) - - with self.model.scope(tenant_id): - output = self.model(inputs) - # Compute loss (simplified - real impl would depend on task) - loss = output.mean() if isinstance(output, torch.Tensor) else torch.tensor(0.0) - self.model.backward(loss) - return {"loss": loss.item()} - - def finish_grad_sync(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.finish_grad_sync(tenant_id) - - def clip_grad_norm(self, tenant_id: str) -> float: - self._heartbeat(tenant_id) - return self.model.clip_grad_norm(tenant_id=tenant_id).item() - - def step(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.step(tenant_id) - - def zero_grad(self, tenant_id: str): - self._heartbeat(tenant_id) - self.model.zero_grad(tenant_id) - - def lr_step(self, 
tenant_id: str): - self._heartbeat(tenant_id) - self.model.lr_step(tenant_id) - - def tenant_count(self) -> int: - """Get number of active tenants (does not expose tenant IDs for privacy).""" - return self.model.tenant_count() - - -# ============ FastAPI App ============ - -def create_app(server: MultiTenantServer) -> FastAPI: - """Create FastAPI app.""" - app = FastAPI(title="Multi-Tenant Megatron Server") - - def get_tenant(request: Request) -> str: - tid = request.headers.get("X-Tenant-ID") - if not tid: - raise HTTPException(400, "Missing X-Tenant-ID") - return tid - - @app.post("/initialize", response_model=TenantResponse) - def initialize(body: InitializeRequest): - tid = server.initialize(body) - return TenantResponse(tenant_id=tid) - - @app.post("/finalize", response_model=TenantResponse) - def finalize(request: Request): - server.finalize(get_tenant(request)) - return TenantResponse() - - @app.post("/forward_backward", response_model=TenantResponse) - def forward_backward(request: Request, body: InputsRequest): - data = server.forward_backward(get_tenant(request), body.inputs) - return TenantResponse(data=data) - - @app.post("/finish_grad_sync", response_model=TenantResponse) - def finish_grad_sync(request: Request): - server.finish_grad_sync(get_tenant(request)) - return TenantResponse() - - @app.post("/clip_grad_norm", response_model=TenantResponse) - def clip_grad_norm(request: Request): - norm = server.clip_grad_norm(get_tenant(request)) - return TenantResponse(data=norm) - - @app.post("/step", response_model=TenantResponse) - def step(request: Request): - server.step(get_tenant(request)) - return TenantResponse() - - @app.post("/zero_grad", response_model=TenantResponse) - def zero_grad(request: Request): - server.zero_grad(get_tenant(request)) - return TenantResponse() - - @app.post("/lr_step", response_model=TenantResponse) - def lr_step(request: Request): - server.lr_step(get_tenant(request)) - return TenantResponse() - - @app.get("/stats") - def 
stats(): - """Server statistics (does not expose tenant IDs for privacy).""" - return {"tenant_count": server.tenant_count()} - - @app.get("/health") - def health(): - return {"status": "healthy"} - - return app - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-id", required=True) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--host", default="0.0.0.0") - parser.add_argument("--port", type=int, default=8080) - args = parser.parse_args() - - logging.basicConfig(level=logging.INFO) - - server = MultiTenantServer(args.model_id, args.tp) - server.setup() - - import uvicorn - uvicorn.run(create_app(server), host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/src/twinkle/megatron/distributed/__init__.py b/src/twinkle/megatron/distributed/__init__.py deleted file mode 100644 index 364bd702..00000000 --- a/src/twinkle/megatron/distributed/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -[WIP] -Distributed training utilities for multi-tenant Megatron LoRA. 
- -Core components: -- tenant_context: ContextVar-based tenant management -- tenant_manager: Tenant lifecycle (adapters, optimizers) -- multi_tenant_ddp: Per-tenant gradient buffers and sync -""" - -from .multi_tenant_ddp import MultiTenantLoRADDP, TenantDDPState -from .tenant_context import (TenantInfo, generate_tenant_id, - get_current_tenant, require_tenant, - set_current_tenant, tenant_scope) -from .tenant_manager import TenantManager, TenantState -from .clock_cycle_scheduler import ( - ClockCycleScheduler, - ClockCycleTrainingClient, - CycleStats, - RequestType, - TrainingRequest, - ModelInterfaceError, - validate_model_interface, -) - -__all__ = [ - # Context - 'get_current_tenant', - 'set_current_tenant', - 'require_tenant', - 'tenant_scope', - 'generate_tenant_id', - 'TenantInfo', - # Manager - 'TenantManager', - 'TenantState', - # DDP (Twinkle mode) - 'MultiTenantLoRADDP', - 'TenantDDPState', - # Clock Cycle Scheduler - 'ClockCycleScheduler', - 'ClockCycleTrainingClient', - 'CycleStats', - 'RequestType', - 'TrainingRequest', - 'ModelInterfaceError', - 'validate_model_interface', -] diff --git a/src/twinkle/megatron/distributed/clock_cycle_scheduler.py b/src/twinkle/megatron/distributed/clock_cycle_scheduler.py deleted file mode 100644 index b1ae11b9..00000000 --- a/src/twinkle/megatron/distributed/clock_cycle_scheduler.py +++ /dev/null @@ -1,598 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -Clock Cycle Scheduler for multi-tenant training. - -This module implements a time-sharing scheduler that batches requests -from multiple tenants into fixed clock cycles. 
- -## Key Concepts - -- Clock Cycle: Fixed time interval where all pending requests are processed -- Request Queue: Collects requests between cycles -- Batched Grad Sync: One communication round for all tenants (efficient) -- Gradient Isolation: Each tenant has separate LoRA params, no gradient overwrite -""" - -import logging -import threading -import time -from collections import defaultdict -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple -from concurrent.futures import Future - -import torch -import torch.nn as nn -import torch.distributed as dist - -logger = logging.getLogger(__name__) - - -# ============ Request Types ============ - -class RequestType(Enum): - """Type of training request.""" - FORWARD_BACKWARD = "forward_backward" - OPTIM_STEP = "optim_step" - ZERO_GRAD = "zero_grad" - - -@dataclass -class TrainingRequest: - """A training request from a tenant.""" - tenant_id: str - request_type: RequestType - inputs: Any = None - labels: Any = None - kwargs: Dict[str, Any] = field(default_factory=dict) - future: Optional[Future] = None - submitted_at: float = field(default_factory=time.time) - - -# ============ Cycle Statistics ============ - -@dataclass -class CycleStats: - """Statistics for a clock cycle.""" - cycle_id: int - start_time: float - end_time: float - num_tenants: int - num_requests: int - total_samples: int - forward_time: float - backward_time: float - grad_sync_time: float - optim_step_time: float - - @property - def duration(self) -> float: - return self.end_time - self.start_time - - @property - def gpu_active_time(self) -> float: - return self.forward_time + self.backward_time + self.optim_step_time - - @property - def gpu_utilization(self) -> float: - if self.duration > 0: - return self.gpu_active_time / self.duration - return 0.0 - - @property - def samples_per_second(self) -> float: - if self.duration > 0: - return self.total_samples / self.duration - return 
0.0 - - -# ============ Model Interface Requirements ============ - -class ModelInterfaceError(Exception): - """Raised when model doesn't implement required interface.""" - pass - - -def validate_model_interface(model: nn.Module) -> None: - """ - Validate that model implements the required interface. - - Required methods: - - scope(tenant_id) -> context manager - - zero_grad(tenant_id) -> None - - step(tenant_id) -> None - - __call__(inputs) -> output (forward) - - Optional methods: - - clip_grad_norm(tenant_id, max_norm) -> None - - finish_grad_sync(tenant_id) -> None - - finish_grad_sync_batched(tenant_ids) -> None - """ - required = ['scope', 'zero_grad', 'step'] - missing = [m for m in required if not hasattr(model, m)] - - if missing: - raise ModelInterfaceError( - f"Model must implement: {required}. Missing: {missing}" - ) - - -# ============ Gradient Synchronization ============ - -class GradientSynchronizer: - """ - Handles gradient synchronization for multiple tenants. - - For distributed training, this batches gradient communication - to reduce the number of NCCL calls. - """ - - def __init__(self, model: nn.Module): - self.model = model - - def sync_individual(self, tenant_id: str) -> float: - """Synchronize gradients for a single tenant.""" - if not dist.is_initialized(): - return 0.0 - - t0 = time.time() - - if hasattr(self.model, 'finish_grad_sync'): - self.model.finish_grad_sync(tenant_id) - - return time.time() - t0 - - def sync_batched(self, tenant_ids: List[str]) -> float: - """ - Synchronize gradients for multiple tenants. - - Uses batched sync if model supports it, otherwise falls back - to individual sync. 
- """ - if not tenant_ids: - return 0.0 - - t0 = time.time() - - if hasattr(self.model, 'finish_grad_sync_batched'): - # Optimized: one call for all tenants - self.model.finish_grad_sync_batched(tenant_ids) - elif dist.is_initialized(): - # Fallback: sync each tenant individually - for tenant_id in tenant_ids: - self.sync_individual(tenant_id) - - return time.time() - t0 - - -# ============ Clock Cycle Scheduler ============ - -class ClockCycleScheduler: - """ - Clock cycle scheduler for multi-tenant training. - - Collects requests from multiple tenants and executes them in batched - clock cycles. While computation is per-tenant serial (due to LoRA - architecture), communication is batched for efficiency. - - ## Benefits - - 1. **Unified Scheduling**: All tenants processed in fixed cycles - 2. **Batched Communication**: One grad sync round for all tenants - 3. **Fair Scheduling**: All pending requests processed together - 4. **Predictable Latency**: Fixed cycle interval - - ## Usage - - ```python - scheduler = ClockCycleScheduler(model, cycle_interval_ms=100) - scheduler.start() - - # From multiple clients - future1 = scheduler.submit_forward_backward('tenant_a', inputs_a) - future2 = scheduler.submit_forward_backward('tenant_b', inputs_b) - - result1 = future1.result() - result2 = future2.result() - - scheduler.stop() - ``` - """ - - def __init__( - self, - model: nn.Module, - cycle_interval_ms: float = 100.0, - loss_fn: Optional[Callable] = None, - ): - """ - Initialize the scheduler. - - Args: - model: The multi-tenant model. Must implement: - - scope(tenant_id) -> context manager - - zero_grad(tenant_id) -> None - - step(tenant_id) -> None - - __call__(inputs) -> output - - cycle_interval_ms: Clock cycle interval in milliseconds. - loss_fn: Loss function (output, labels) -> loss. - Default: output.mean() - - Raises: - ModelInterfaceError: If model doesn't implement required methods. 
- """ - # Validate model interface - validate_model_interface(model) - - self.model = model - self.cycle_interval = cycle_interval_ms / 1000.0 - self.loss_fn = loss_fn or self._default_loss_fn - - # Gradient synchronizer - self._grad_sync = GradientSynchronizer(model) - - # Request queue (thread-safe) - self._queue_lock = threading.Lock() - self._request_queue: Dict[str, List[TrainingRequest]] = defaultdict(list) - - # Cycle management - self._running = False - self._cycle_thread: Optional[threading.Thread] = None - self._current_cycle_id = 0 - - # Statistics - self._stats: List[CycleStats] = [] - self._stats_lock = threading.Lock() - - def _default_loss_fn(self, output: torch.Tensor, labels: Any) -> torch.Tensor: - """Default loss function (mean of output).""" - if isinstance(output, torch.Tensor): - return output.mean() - raise ValueError(f"Cannot compute loss on {type(output)}, provide loss_fn") - - def start(self): - """Start the clock cycle loop.""" - if self._running: - return - - self._running = True - self._cycle_thread = threading.Thread( - target=self._cycle_loop, - name="ClockCycleLoop", - daemon=True, - ) - self._cycle_thread.start() - logger.info(f"Clock cycle scheduler started (interval={self.cycle_interval*1000:.0f}ms)") - - def stop(self): - """Stop the clock cycle loop.""" - self._running = False - if self._cycle_thread: - self._cycle_thread.join(timeout=5.0) - self._cycle_thread = None - logger.info("Clock cycle scheduler stopped") - - def submit_forward_backward( - self, - tenant_id: str, - inputs: Any, - labels: Any = None, - **kwargs, - ) -> Future: - """ - Submit a forward-backward request. - - Returns immediately with a Future containing the result. 
- """ - future = Future() - request = TrainingRequest( - tenant_id=tenant_id, - request_type=RequestType.FORWARD_BACKWARD, - inputs=inputs, - labels=labels, - kwargs=kwargs, - future=future, - ) - - with self._queue_lock: - self._request_queue[tenant_id].append(request) - - return future - - def submit_optim_step(self, tenant_id: str, **kwargs) -> Future: - """Submit an optimizer step request.""" - future = Future() - request = TrainingRequest( - tenant_id=tenant_id, - request_type=RequestType.OPTIM_STEP, - kwargs=kwargs, - future=future, - ) - - with self._queue_lock: - self._request_queue[tenant_id].append(request) - - return future - - def submit_zero_grad(self, tenant_id: str) -> Future: - """Submit a zero_grad request.""" - future = Future() - request = TrainingRequest( - tenant_id=tenant_id, - request_type=RequestType.ZERO_GRAD, - future=future, - ) - - with self._queue_lock: - self._request_queue[tenant_id].append(request) - - return future - - def _cycle_loop(self): - """Main clock cycle loop.""" - while self._running: - cycle_start = time.time() - - # Collect pending requests (deep copy for thread safety) - with self._queue_lock: - pending = { - tenant_id: list(reqs) - for tenant_id, reqs in self._request_queue.items() - if reqs - } - self._request_queue.clear() - - # Execute cycle if there are requests - if pending: - self._execute_cycle(pending) - - # Wait for next cycle - elapsed = time.time() - cycle_start - sleep_time = max(0, self.cycle_interval - elapsed) - if sleep_time > 0: - time.sleep(sleep_time) - - def _execute_cycle(self, requests: Dict[str, List[TrainingRequest]]): - """ - Execute one clock cycle. - - Phases: - 1. Forward-backward for each tenant (serial due to LoRA architecture) - 2. Batched gradient synchronization (efficient) - 3. Optimizer step for each tenant - - Note: Forward-backward is serial per tenant because each tenant's - LoRA weights are embedded in every layer. 
There's no way to "merge" - computation across tenants for Transformer models. - - However, gradient sync is batched, reducing communication overhead. - """ - cycle_start = time.time() - self._current_cycle_id += 1 - cycle_id = self._current_cycle_id - - # Group requests by type - fwd_bwd_reqs: Dict[str, TrainingRequest] = {} - optim_step_reqs: Dict[str, TrainingRequest] = {} - zero_grad_reqs: Dict[str, TrainingRequest] = {} - - for tenant_id, reqs in requests.items(): - for req in reqs: - if req.request_type == RequestType.FORWARD_BACKWARD: - fwd_bwd_reqs[tenant_id] = req - elif req.request_type == RequestType.OPTIM_STEP: - optim_step_reqs[tenant_id] = req - elif req.request_type == RequestType.ZERO_GRAD: - zero_grad_reqs[tenant_id] = req - - num_tenants = len(set(fwd_bwd_reqs.keys()) | set(optim_step_reqs.keys())) - num_requests = sum(len(reqs) for reqs in requests.values()) - - logger.debug(f"Cycle {cycle_id}: {num_tenants} tenants, {num_requests} requests") - - # Tracking - successful_tenants = [] - failed_tenants = [] - total_samples = 0 - forward_time = 0.0 - backward_time = 0.0 - grad_sync_time = 0.0 - optim_step_time = 0.0 - - try: - # ============ PHASE 1: Forward-Backward (per tenant) ============ - # Each tenant's forward-backward is independent because: - # 1. Each tenant has separate LoRA parameters - # 2. Gradients accumulate to each tenant's own LoRA params - # 3. 
No gradient overwrite between tenants - - for tenant_id, req in fwd_bwd_reqs.items(): - try: - inputs = req.inputs - labels = req.labels - - # Get batch size for stats - if isinstance(inputs, torch.Tensor): - batch_size = inputs.size(0) - elif isinstance(inputs, dict) and 'input_ids' in inputs: - batch_size = inputs['input_ids'].size(0) - else: - batch_size = 1 - - total_samples += batch_size - - with self.model.scope(tenant_id): - # Forward - t0 = time.time() - output = self.model(inputs) - forward_time += time.time() - t0 - - # Loss - loss = self.loss_fn(output, labels) - - # Backward - # Note: Each tenant's LoRA params are separate, - # so loss.backward() accumulates gradients to - # this tenant's params only - no overwrite - t0 = time.time() - loss.backward() - backward_time += time.time() - t0 - - # Record result - result = { - 'loss': loss.item() if hasattr(loss, 'item') else float(loss), - 'cycle_id': cycle_id, - 'batch_size': batch_size, - } - req.future.set_result(result) - successful_tenants.append(tenant_id) - - except Exception as e: - logger.error(f"Tenant {tenant_id} forward-backward failed: {e}") - failed_tenants.append(tenant_id) - req.future.set_exception(e) - - # Clean up failed tenant's gradient state - try: - self.model.zero_grad(tenant_id) - except Exception: - pass - - # ============ PHASE 2: Batched Gradient Sync ============ - # This is where we get efficiency: one sync round for all tenants - if successful_tenants: - grad_sync_time = self._grad_sync.sync_batched(successful_tenants) - - # ============ PHASE 3: Optimizer Step ============ - for tenant_id, req in optim_step_reqs.items(): - try: - t0 = time.time() - - # Clip gradients (optional) - if hasattr(self.model, 'clip_grad_norm'): - self.model.clip_grad_norm(tenant_id=tenant_id) - - # Optimizer step - self.model.step(tenant_id) - - # Zero grad after step - self.model.zero_grad(tenant_id) - - optim_step_time += time.time() - t0 - req.future.set_result({'cycle_id': cycle_id}) - - except 
Exception as e: - logger.error(f"Tenant {tenant_id} optimizer step failed: {e}") - req.future.set_exception(e) - - # ============ PHASE 4: Standalone zero_grad ============ - for tenant_id, req in zero_grad_reqs.items(): - if tenant_id not in optim_step_reqs: - try: - self.model.zero_grad(tenant_id) - req.future.set_result(None) - except Exception as e: - req.future.set_exception(e) - - except Exception as e: - logger.exception(f"Cycle {cycle_id} failed: {e}") - for reqs in requests.values(): - for req in reqs: - if not req.future.done(): - req.future.set_exception(e) - - # Record stats - cycle_end = time.time() - stats = CycleStats( - cycle_id=cycle_id, - start_time=cycle_start, - end_time=cycle_end, - num_tenants=num_tenants, - num_requests=num_requests, - total_samples=total_samples, - forward_time=forward_time, - backward_time=backward_time, - grad_sync_time=grad_sync_time, - optim_step_time=optim_step_time, - ) - - with self._stats_lock: - self._stats.append(stats) - - logger.debug( - f"Cycle {cycle_id} completed: duration={stats.duration*1000:.1f}ms, " - f"tenants={num_tenants}, samples={total_samples}" - ) - - def get_stats(self) -> List[CycleStats]: - """Get all cycle statistics.""" - with self._stats_lock: - return list(self._stats) - - def get_summary_stats(self) -> Dict[str, float]: - """Get summary statistics.""" - with self._stats_lock: - if not self._stats: - return {} - - total_cycles = len(self._stats) - total_duration = sum(s.duration for s in self._stats) - total_samples = sum(s.total_samples for s in self._stats) - total_forward = sum(s.forward_time for s in self._stats) - total_backward = sum(s.backward_time for s in self._stats) - total_sync = sum(s.grad_sync_time for s in self._stats) - total_optim = sum(s.optim_step_time for s in self._stats) - total_gpu_time = total_forward + total_backward + total_optim - - return { - 'total_cycles': total_cycles, - 'total_duration': total_duration, - 'total_samples': total_samples, - 'total_gpu_time': 
total_gpu_time, - 'total_comm_time': total_sync, - 'avg_cycle_duration': total_duration / total_cycles, - 'gpu_utilization': total_gpu_time / total_duration if total_duration > 0 else 0, - 'throughput_samples_per_sec': total_samples / total_duration if total_duration > 0 else 0, - } - - -# ============ Training Client ============ - -class ClockCycleTrainingClient: - """ - Client for clock cycle scheduler. - - Provides a simple API for submitting training requests. - """ - - def __init__(self, scheduler: ClockCycleScheduler, tenant_id: str): - self.scheduler = scheduler - self.tenant_id = tenant_id - - def forward_backward(self, inputs: Any, labels: Any = None) -> Future: - """Submit forward-backward (returns Future).""" - return self.scheduler.submit_forward_backward(self.tenant_id, inputs, labels) - - def optim_step(self) -> Future: - """Submit optimizer step (returns Future).""" - return self.scheduler.submit_optim_step(self.tenant_id) - - def zero_grad(self) -> Future: - """Submit zero_grad (returns Future).""" - return self.scheduler.submit_zero_grad(self.tenant_id) - - def train_step(self, inputs: Any, labels: Any = None) -> Dict[str, Any]: - """ - Execute a complete training step (blocking). - - Submits forward_backward and optim_step in the same cycle. - """ - fwd_future = self.forward_backward(inputs, labels) - opt_future = self.optim_step() - - result = fwd_future.result() - opt_future.result() - - return result diff --git a/src/twinkle/megatron/distributed/multi_tenant_ddp.py b/src/twinkle/megatron/distributed/multi_tenant_ddp.py deleted file mode 100644 index e046555a..00000000 --- a/src/twinkle/megatron/distributed/multi_tenant_ddp.py +++ /dev/null @@ -1,398 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -Multi-Tenant LoRA DDP for Megatron models. - -This module provides a DDP implementation for multi-tenant LoRA training, -inheriting from Megatron's DistributedDataParallel. - -Key Design: -1. 
Inherits from MegatronDDP for code reuse -2. Overrides buffer/bucket creation to be per-tenant -3. Uses ContextVar for automatic tenant resolution -4. Tenant lifecycle managed by TenantManager (separate concern) -""" - -import logging -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Dict, List, Optional - -import torch -import torch.distributed as dist -import torch.nn as nn - -from .tenant_context import get_current_tenant, require_tenant, tenant_scope - -logger = logging.getLogger(__name__) - -try: - from megatron.core import parallel_state as mpu - from megatron.core.distributed import DistributedDataParallel as MegatronDDP - from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig - from megatron.core.distributed.param_and_grad_buffer import ( - _ParamAndGradBuffer, - partition_buckets, - ) - from megatron.core.process_groups_config import ProcessGroupCollection - from megatron.core.transformer.transformer_config import TransformerConfig - MEGATRON_AVAILABLE = True -except ImportError: - MEGATRON_AVAILABLE = False - - # Fallback for type hints - class MegatronDDP(nn.Module): - pass - - -@dataclass -class TenantDDPState: - """Per-tenant DDP state: buffers, bucket groups, hooks.""" - tenant_id: str - params: List[nn.Parameter] = field(default_factory=list) - buffers: List = field(default_factory=list) - bucket_groups: List = field(default_factory=list) - param_to_bucket_group: Dict[nn.Parameter, - object] = field(default_factory=dict) - grad_accs: List = field(default_factory=list) - process_group: Optional[dist.ProcessGroup] = None - - -class MultiTenantLoRADDP(MegatronDDP): - """ - Multi-Tenant LoRA DDP inheriting from MegatronDDP. - - This class extends MegatronDDP to support per-tenant gradient buffers - and communication. The key difference is that instead of creating - buffers for all parameters at init, buffers are created dynamically - for each tenant. 
- - Comparison with MegatronDDP: - - MegatronDDP: Creates buffers for all requires_grad=True params at __init__ - - MultiTenantLoRADDP: Creates buffers per-tenant when add_tenant is called - - Usage: - >>> # Create with frozen base model (no trainable params yet) - >>> ddp = MultiTenantLoRADDP(config, ddp_config, model) - >>> - >>> # Add tenant (creates buffers for their LoRA params) - >>> ddp.add_tenant('tenant_a', params_a, process_group_a) - >>> - >>> # Training uses current tenant context - >>> with tenant_scope('tenant_a'): - ... ddp.zero_grad_buffer() # Zeros tenant_a's buffers - ... output = ddp(input) - ... loss.backward() - ... ddp.finish_grad_sync() # Syncs tenant_a's gradients - >>> - >>> # Remove tenant - >>> ddp.remove_tenant('tenant_a') - """ - def __init__( - self, - config: 'TransformerConfig', - ddp_config: 'DistributedDataParallelConfig', - module: nn.Module, - disable_bucketing: bool = False, - pg_collection: Optional['ProcessGroupCollection'] = None, - ): - """ - Initialize MultiTenantLoRADDP. - - Unlike MegatronDDP, this does NOT create buffers at init. - Buffers are created per-tenant via add_tenant(). - - Args: - config: Transformer config. - ddp_config: DDP config. - module: Model (base model should be frozen). - disable_bucketing: Disable bucketing. - pg_collection: Process group collection. 
- """ - if not MEGATRON_AVAILABLE: - raise ImportError('Megatron-Core is required') - - # Skip MegatronDDP's buffer creation by temporarily setting all params to not require grad - original_requires_grad = {} - for name, param in module.named_parameters(): - original_requires_grad[name] = param.requires_grad - param.requires_grad = False - - # Call parent init (will create empty buffers since no params require grad) - super().__init__( - config=config, - ddp_config=ddp_config, - module=module, - disable_bucketing=disable_bucketing, - pg_collection=pg_collection, - ) - - # Restore requires_grad - for name, param in module.named_parameters(): - param.requires_grad = original_requires_grad[name] - - # Per-tenant state - self._tenant_states: Dict[str, TenantDDPState] = {} - - logger.info('MultiTenantLoRADDP initialized (no buffers yet)') - - def add_tenant( - self, - tenant_id: str, - params: List[nn.Parameter], - process_group: Optional[dist.ProcessGroup] = None, - param_names: Optional[Dict[nn.Parameter, str]] = None, - ): - """ - Add a tenant with their gradient buffers. - - This creates per-tenant buffers and hooks, similar to what - MegatronDDP.__init__ does but scoped to this tenant. - - Args: - tenant_id: Unique tenant ID. - params: Trainable parameters for this tenant. - process_group: Process group for gradient sync. - param_names: Param to name mapping for debugging. 
- """ - if tenant_id in self._tenant_states: - raise ValueError(f"Tenant '{tenant_id}' already exists") - - if not params: - raise ValueError('No parameters provided') - - process_group = process_group or self.intra_dp_cp_group - param_names = param_names or {} - - # Build param_names if not provided - if not param_names: - for name, param in self.module.named_parameters(): - if param in params: - param_names[param] = name - - # Create tenant state - state = TenantDDPState( - tenant_id=tenant_id, - params=params, - process_group=process_group, - ) - - # Initialize grad flags - for param in params: - param.grad_added_to_main_grad = False - - # Create buffers - self._create_tenant_buffers(state, param_names) - - # Register hooks - self._register_tenant_hooks(state) - - self._tenant_states[tenant_id] = state - - logger.info(f"Added tenant '{tenant_id}' with {len(params)} params, " - f'{len(state.bucket_groups)} bucket groups') - - def _create_tenant_buffers( - self, - state: TenantDDPState, - param_names: Dict[nn.Parameter, str], - ): - """Create gradient buffers for a tenant.""" - # Group by dtype - param_and_grad_dtype_to_params = {} - param_and_grad_dtype_to_indices = {} - - for param in state.params: - param_dtype = param.dtype - grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype - - key = (param_dtype, grad_dtype) - if key not in param_and_grad_dtype_to_params: - param_and_grad_dtype_to_params[key] = [] - param_and_grad_dtype_to_indices[key] = [] - - param_and_grad_dtype_to_params[key].append(param) - param_and_grad_dtype_to_indices[key].append( - len(param_and_grad_dtype_to_params[key]) - 1) - - # Calculate gradient scaling - if self.config.calculate_per_token_loss: - gradient_scaling_factor = 1.0 - elif self.ddp_config.average_in_collective: - gradient_scaling_factor = 1.0 - else: - gradient_scaling_factor = 1.0 / state.process_group.size() - - # ProcessGroupCollection for buffer creation - pg_collection = ProcessGroupCollection() 
- pg_collection.tp = self.tp_group - pg_collection.dp_cp = state.process_group - - # Create buffers - for (param_dtype, - grad_dtype), params in param_and_grad_dtype_to_params.items(): - indices = param_and_grad_dtype_to_indices[(param_dtype, - grad_dtype)] - - buffer = _ParamAndGradBuffer( - self.ddp_config, - param_dtype, - grad_dtype, - params, - state.process_group, - self.bucket_size, - param_names, - gradient_scaling_factor, - indices, - getattr(self.ddp_config, 'nccl_ub', False), - pg_collection, - ) - state.buffers.append(buffer) - - # Create bucket groups - state.bucket_groups = partition_buckets( - state.buffers, - force_single_bucket_group=(self.bucket_size is None), - ) - - # Build param to bucket group mapping - for bucket_group in state.bucket_groups: - for bucket in bucket_group.buckets: - for param in bucket.params_list: - state.param_to_bucket_group[param] = bucket_group - - def _register_tenant_hooks(self, state: TenantDDPState): - """Register backward hooks for a tenant.""" - for param in state.params: - if param not in state.param_to_bucket_group: - continue - - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook( - self._make_tenant_backward_hook(param, state)) - state.grad_accs.append(grad_acc) - - def _make_tenant_backward_hook(self, param: nn.Parameter, - state: TenantDDPState): - """Create backward hook for a tenant's parameter.""" - def hook(*unused): - if param in state.param_to_bucket_group: - if param.grad is not None and not param.grad_added_to_main_grad: - param.main_grad.add_(param.grad.data) - param.grad = None - - if self.ddp_config.overlap_grad_reduce: - bucket_group = state.param_to_bucket_group[param] - if bucket_group.is_last_microbatch: - bucket_group.register_grad_ready(param) - - return hook - - def remove_tenant(self, tenant_id: str): - """Remove a tenant and cleanup their resources.""" - if tenant_id not in self._tenant_states: - raise KeyError(f"Tenant 
'{tenant_id}' not found") - - state = self._tenant_states.pop(tenant_id) - - # Clear hooks - state.grad_accs.clear() - - # Clear buffers - state.buffers.clear() - state.bucket_groups.clear() - state.param_to_bucket_group.clear() - - # Clear param attributes - for param in state.params: - if hasattr(param, 'main_grad'): - delattr(param, 'main_grad') - if hasattr(param, 'grad_added_to_main_grad'): - delattr(param, 'grad_added_to_main_grad') - - logger.info(f"Removed tenant '{tenant_id}'") - - def _get_tenant_state(self, - tenant_id: Optional[str] = None) -> TenantDDPState: - """Get state for tenant (uses context if not specified).""" - tenant_id = tenant_id or require_tenant() - if tenant_id not in self._tenant_states: - raise KeyError(f"Tenant '{tenant_id}' not registered") - return self._tenant_states[tenant_id] - - # ========== Override MegatronDDP methods to be tenant-aware ========== - - @contextmanager - def no_sync(self, tenant_id: Optional[str] = None): - """Disable gradient sync for a tenant.""" - state = self._get_tenant_state(tenant_id) - for bucket_group in state.bucket_groups: - bucket_group.is_last_microbatch = False - try: - yield - finally: - for bucket_group in state.bucket_groups: - bucket_group.is_last_microbatch = True - - def start_grad_sync(self, tenant_id: Optional[str] = None): - """Start gradient sync for a tenant.""" - state = self._get_tenant_state(tenant_id) - for bucket_group in state.bucket_groups: - bucket_group.start_grad_sync() - - def finish_grad_sync(self, tenant_id: Optional[str] = None): - """Finish gradient sync for a tenant.""" - state = self._get_tenant_state(tenant_id) - for bucket_group in state.bucket_groups: - bucket_group.finish_grad_sync() - - def zero_grad_buffer(self, tenant_id: Optional[str] = None): - """Zero gradient buffers for a tenant.""" - state = self._get_tenant_state(tenant_id) - - for param in state.params: - param.grad_added_to_main_grad = False - - for buffer in state.buffers: - buffer.reset() - - for 
bucket_group in state.bucket_groups: - bucket_group.reset() - - def scale_gradients(self, - scaling_factor: float, - tenant_id: Optional[str] = None): - """Scale gradients for a tenant.""" - state = self._get_tenant_state(tenant_id) - for buffer in state.buffers: - buffer.scale_gradients(scaling_factor) - - def broadcast_params(self, tenant_id: Optional[str] = None): - """Broadcast parameters for a tenant.""" - state = self._get_tenant_state(tenant_id) - for param in state.params: - dist.broadcast( - param.data, - src=dist.get_global_rank(state.process_group, 0), - group=state.process_group, - ) - - # ========== Utility ========== - - def has_tenant(self, tenant_id: str) -> bool: - """Check if tenant exists.""" - return tenant_id in self._tenant_states - - def list_tenants(self) -> List[str]: - """List all tenants.""" - return list(self._tenant_states.keys()) - - def get_tenant_params(self, - tenant_id: Optional[str] = None - ) -> List[nn.Parameter]: - """Get parameters for a tenant (requires valid tenant context).""" - state = self._get_tenant_state(tenant_id) - return state.params - - # Note: list_tenants() intentionally not exposed to prevent - # information leakage between tenants. Use has_tenant() instead. diff --git a/src/twinkle/megatron/distributed/tenant_context.py b/src/twinkle/megatron/distributed/tenant_context.py deleted file mode 100644 index 8ec5b2f6..00000000 --- a/src/twinkle/megatron/distributed/tenant_context.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -Tenant context management using ContextVar. - -This module provides process-level tenant context that automatically -propagates through async calls and threads, eliminating the need to -manually pass tenant_id to every method. 
-""" - -import contextvars -import uuid -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, TypeVar - -import torch.distributed as dist - -# Global ContextVar for current tenant - process level -_current_tenant: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( - 'current_tenant', default=None -) - - -def get_current_tenant() -> Optional[str]: - """Get the current tenant ID from context.""" - return _current_tenant.get() - - -def set_current_tenant(tenant_id: Optional[str]) -> contextvars.Token: - """Set the current tenant ID in context.""" - return _current_tenant.set(tenant_id) - - -def require_tenant() -> str: - """Get current tenant ID, raising error if not set.""" - tenant_id = _current_tenant.get() - if tenant_id is None: - raise RuntimeError( - "No tenant context set. Use 'with tenant_scope(tenant_id):' or " - "call 'initialize()' first." - ) - return tenant_id - - -@contextmanager -def tenant_scope(tenant_id: str): - """ - Context manager to set the current tenant for a block of code. - - Example: - >>> with tenant_scope('user_a'): - ... model.forward(input) # Uses user_a's LoRA - ... loss.backward() - ... ddp.finish_grad_sync() # Only syncs user_a's gradients - """ - token = _current_tenant.set(tenant_id) - try: - yield tenant_id - finally: - _current_tenant.reset(token) - - -def generate_tenant_id() -> str: - """Generate a unique tenant ID.""" - return str(uuid.uuid4())[:8] - - -@dataclass -class TenantInfo: - """ - Information about a registered tenant. - - This is a lightweight dataclass that stores tenant metadata, - separate from DDP-specific state. - """ - tenant_id: str - adapter_name: str - process_group: Optional[dist.ProcessGroup] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - -F = TypeVar('F', bound=Callable) - - -def with_tenant_context(func: F) -> F: - """ - Decorator that automatically uses the current tenant context. 
- - The decorated function should have an optional 'tenant_id' parameter. - If not provided, it will use the current tenant from context. - - Example: - >>> @with_tenant_context - ... def finish_grad_sync(self, tenant_id: Optional[str] = None): - ... # tenant_id is automatically set from context if None - ... ... - """ - import functools - - @functools.wraps(func) - def wrapper(*args, tenant_id: Optional[str] = None, **kwargs): - if tenant_id is None: - tenant_id = require_tenant() - return func(*args, tenant_id=tenant_id, **kwargs) - - return wrapper # type: ignore diff --git a/src/twinkle/megatron/distributed/tenant_manager.py b/src/twinkle/megatron/distributed/tenant_manager.py deleted file mode 100644 index 260709a4..00000000 --- a/src/twinkle/megatron/distributed/tenant_manager.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -Tenant Manager for multi-tenant LoRA training. - -This module provides tenant lifecycle management including: -- Tenant registration/deregistration -- LoRA adapter management (via PEFT) -- Optimizer/scheduler creation -- Tenant context switching - -""" - -import logging -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Type - -import torch -import torch.distributed as dist -import torch.nn as nn - -from .tenant_context import ( - generate_tenant_id, - get_current_tenant, - require_tenant, - set_current_tenant, - tenant_scope, -) - -logger = logging.getLogger(__name__) - - -from peft import LoraConfig, PeftModel - -@dataclass -class TenantState: - """ - State for a single tenant. 
- - Contains: - - Identity: tenant_id, adapter_name - - Training: optimizer, scheduler, params - - Config: gradient accumulation, max grad norm - """ - tenant_id: str - adapter_name: str - - # Parameters - params: List[nn.Parameter] = field(default_factory=list) - param_names: Dict[nn.Parameter, str] = field(default_factory=dict) - - # Training components - optimizer: Optional[torch.optim.Optimizer] = None - scheduler: Optional[Any] = None - - # Training config - gradient_accumulation_steps: int = 1 - max_grad_norm: float = 1.0 - - # Process group for this tenant - process_group: Optional[dist.ProcessGroup] = None - - -class TenantManager: - """ - Manages tenant lifecycle for multi-tenant training. - - Responsibilities: - 1. Tenant registration/deregistration - 2. LoRA adapter management - 3. Optimizer/scheduler creation - 4. Tenant context switching - - This class is decoupled from DDP - it only manages tenant metadata - and training components, not gradient buffers or communication. - - Example: - >>> manager = TenantManager(model) - >>> - >>> # Initialize tenant - >>> tenant_id = manager.initialize( - ... lora_config=LoraConfig(r=8), - ... optimizer_cls=AdamW, - ... ) - >>> - >>> # Use tenant context - >>> with manager.scope(tenant_id): - ... # All operations use this tenant - ... pass - >>> - >>> # Cleanup - >>> manager.finalize(tenant_id) - """ - - def __init__( - self, - model: nn.Module, - default_process_group: Optional[dist.ProcessGroup] = None, - ): - """ - Initialize tenant manager. - - Args: - model: Model with LoRA structure. - default_process_group: Default process group for tenants. 
- """ - self.model = model - self.default_process_group = default_process_group - self._tenants: Dict[str, TenantState] = {} - - # Callbacks for DDP integration - self._on_add_callbacks: List[Callable[[TenantState], None]] = [] - self._on_remove_callbacks: List[Callable[[TenantState], None]] = [] - - def register_add_callback(self, callback: Callable[[TenantState], None]): - """Register callback to be called when tenant is added.""" - self._on_add_callbacks.append(callback) - - def register_remove_callback(self, callback: Callable[[TenantState], None]): - """Register callback to be called when tenant is removed.""" - self._on_remove_callbacks.append(callback) - - def initialize( - self, - lora_config: Optional['LoraConfig'] = None, - optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.AdamW, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler_cls: Optional[Type] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, - gradient_accumulation_steps: int = 1, - max_grad_norm: float = 1.0, - process_group: Optional[dist.ProcessGroup] = None, - adapter_name: Optional[str] = None, - tenant_id: Optional[str] = None, - ) -> str: - """ - Initialize a new tenant. - - Args: - lora_config: LoRA configuration. - optimizer_cls: Optimizer class. - optimizer_kwargs: Optimizer arguments. - scheduler_cls: Scheduler class. - scheduler_kwargs: Scheduler arguments. - gradient_accumulation_steps: Steps to accumulate. - max_grad_norm: Max gradient norm for clipping. - process_group: Process group for gradient sync. - adapter_name: Adapter name (defaults to tenant_id). - tenant_id: Tenant ID (generated if not provided). - - Returns: - The tenant ID. 
- """ - tenant_id = tenant_id or generate_tenant_id() - adapter_name = adapter_name or tenant_id - process_group = process_group or self.default_process_group - - if tenant_id in self._tenants: - raise ValueError(f"Tenant '{tenant_id}' already exists") - - # Add LoRA adapter - if lora_config is not None and isinstance(self.model, PeftModel): - lora_config.modules_to_save = None - lora_config.bias = 'none' - self.model.add_adapter(adapter_name, lora_config) - logger.info(f"Added LoRA adapter '{adapter_name}'") - - # Find trainable params - params = [] - param_names = {} - - for name, param in self.model.named_parameters(): - if f'.{adapter_name}.' in name and 'lora_' in name: - param.requires_grad = True - params.append(param) - param_names[param] = name - - if not params: - logger.warning(f"No trainable params found for tenant '{tenant_id}'") - - # Create optimizer - optimizer_kwargs = optimizer_kwargs or {'lr': 1e-4} - optimizer = optimizer_cls(params, **optimizer_kwargs) if params else None - - # Create scheduler - scheduler = None - if scheduler_cls and optimizer: - scheduler_kwargs = scheduler_kwargs or {} - scheduler = scheduler_cls(optimizer, **scheduler_kwargs) - - # Create state - state = TenantState( - tenant_id=tenant_id, - adapter_name=adapter_name, - params=params, - param_names=param_names, - optimizer=optimizer, - scheduler=scheduler, - gradient_accumulation_steps=gradient_accumulation_steps, - max_grad_norm=max_grad_norm, - process_group=process_group, - ) - - self._tenants[tenant_id] = state - - # Notify callbacks (for DDP integration) - for callback in self._on_add_callbacks: - callback(state) - - # Set as current tenant - set_current_tenant(tenant_id) - - logger.info( - f"Initialized tenant '{tenant_id}' with {len(params)} params " - f"({sum(p.numel() for p in params):,} elements)" - ) - - return tenant_id - - def finalize(self, tenant_id: Optional[str] = None): - """ - Finalize a tenant and cleanup resources. 
- - Args: - tenant_id: Tenant to finalize. Uses current if None. - """ - tenant_id = tenant_id or get_current_tenant() - if not tenant_id or tenant_id not in self._tenants: - return - - state = self._tenants.pop(tenant_id) - - # Notify callbacks (for DDP cleanup) - for callback in self._on_remove_callbacks: - callback(state) - - # Remove adapter - if isinstance(self.model, PeftModel): - try: - self.model.delete_adapter(state.adapter_name) - except Exception as e: - logger.warning(f"Failed to delete adapter: {e}") - - # Clear context if current - if get_current_tenant() == tenant_id: - set_current_tenant(None) - - logger.info(f"Finalized tenant '{tenant_id}'") - - @contextmanager - def scope(self, tenant_id: Optional[str] = None): - """Context manager for tenant scope.""" - tenant_id = tenant_id or require_tenant() - with tenant_scope(tenant_id): - yield self.get(tenant_id) - - def get(self, tenant_id: Optional[str] = None) -> TenantState: - """Get tenant state.""" - tenant_id = tenant_id or require_tenant() - if tenant_id not in self._tenants: - raise KeyError(f"Tenant '{tenant_id}' not found") - return self._tenants[tenant_id] - - def has(self, tenant_id: str) -> bool: - """Check if tenant exists.""" - return tenant_id in self._tenants - - def count(self) -> int: - """Number of tenants (does not expose tenant IDs for privacy).""" - return len(self._tenants) diff --git a/src/twinkle/megatron/model/multi_tenant_megatron.py b/src/twinkle/megatron/model/multi_tenant_megatron.py deleted file mode 100644 index f56b37d2..00000000 --- a/src/twinkle/megatron/model/multi_tenant_megatron.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) twinkle authors. All rights reserved. -""" -Multi-Tenant Megatron Model for LoRA training. - -This module integrates TenantManager and MultiTenantLoRADDP to provide -a complete multi-tenant training solution. 
-""" - -import contextvars -import logging -import re -from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Type - -import torch -import torch.distributed as dist -import torch.nn as nn - -from ..distributed.multi_tenant_ddp import MultiTenantLoRADDP -from ..distributed.tenant_context import (get_current_tenant, require_tenant, - set_current_tenant, tenant_scope) -from ..distributed.tenant_manager import TenantManager, TenantState - -logger = logging.getLogger(__name__) - -try: - from megatron.core import parallel_state as mpu - from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig - from megatron.core.transformer.transformer_config import TransformerConfig - MEGATRON_AVAILABLE = True -except ImportError: - MEGATRON_AVAILABLE = False - -try: - from peft.tuners.lora import LoraLayer, LoraModel - PEFT_AVAILABLE = True -except ImportError: - PEFT_AVAILABLE = False - - -class MegatronMultiAdapter: - """ - Patches LoRA layers to use ContextVar-based adapter selection. - - This enables thread-safe multi-tenant training where each tenant's - active adapter is determined by the current context. 
- """ - - _adapter_var: contextvars.ContextVar[ - Optional[str]] = contextvars.ContextVar('adapter_names', default=None) - _patched: bool = False - - def __call__(self, module: nn.Module) -> nn.Module: - """Patch LoRA layers.""" - if MegatronMultiAdapter._patched: - return module - - self._patch_peft_lora() - self._patch_twinkle_lora() - - module.set_current_adapter_name = MegatronMultiAdapter.set_current_adapter_name - MegatronMultiAdapter._patched = True - - return module - - def _patch_peft_lora(self): - """Patch PEFT's LoraLayer/LoraModel.""" - if not PEFT_AVAILABLE: - return - - if getattr(LoraLayer, '_patched', False): - return - - def get_active_adapter(*args): - return MegatronMultiAdapter._adapter_var.get() - - def get_active_adapters(*args): - adapter = MegatronMultiAdapter._adapter_var.get() - return [adapter] if adapter else [] - - LoraLayer.active_adapter = property(get_active_adapter) - LoraLayer.active_adapters = property(get_active_adapters) - LoraLayer.set_adapter = lambda self, x: None - LoraLayer._patched = True - - LoraModel.active_adapter = property(get_active_adapter) - LoraModel.active_adapters = property(get_active_adapters) - LoraModel.set_adapter = lambda self, x: None - LoraModel._patched = True - - logger.info('Patched PEFT LoraLayer/LoraModel') - - def _patch_twinkle_lora(self): - """Patch Twinkle's LoraParallelLinear.""" - try: - from twinkle.megatron.tuners.lora import LoraParallelLinear - if hasattr(LoraParallelLinear, '_patched'): - return - - def get_active_adapter(self): - return MegatronMultiAdapter._adapter_var.get() - - def get_active_adapters(self): - adapter = MegatronMultiAdapter._adapter_var.get() - return [adapter] if adapter else [] - - LoraParallelLinear.active_adapter = property(get_active_adapter) - LoraParallelLinear.active_adapters = property(get_active_adapters) - LoraParallelLinear._patched = True - logger.info('Patched LoraParallelLinear') - except ImportError: - pass - - @staticmethod - def 
set_current_adapter_name(name: Optional[str]): - """Set current adapter.""" - MegatronMultiAdapter._adapter_var.set(name) - - @staticmethod - def get_current_adapter_name() -> Optional[str]: - """Get current adapter.""" - return MegatronMultiAdapter._adapter_var.get() - - -class MultiTenantMegatronModel(nn.Module): - """ - Multi-tenant Megatron model wrapper. - - Combines: - - TenantManager: Tenant lifecycle (adapters, optimizers) - - MultiTenantLoRADDP: Per-tenant gradient sync - - MegatronMultiAdapter: Context-based adapter selection - - Example: - >>> model = MultiTenantMegatronModel(base_model, config) - >>> - >>> # Initialize tenant (creates adapter, buffers, optimizer) - >>> tenant_id = model.initialize(lora_config=LoraConfig(r=8)) - >>> - >>> # Training (uses current tenant automatically) - >>> model.zero_grad() - >>> output = model(input) - >>> loss = compute_loss(output) - >>> model.backward(loss) - >>> model.finish_grad_sync() - >>> model.step() - >>> - >>> # Cleanup - >>> model.finalize() - """ - def __init__( - self, - model: nn.Module, - config: 'TransformerConfig', - ddp_config: Optional['DistributedDataParallelConfig'] = None, - ): - """ - Initialize. - - Args: - model: Base model with LoRA structure. - config: Transformer config. - ddp_config: DDP config. 
- """ - super().__init__() - - if not MEGATRON_AVAILABLE: - raise ImportError('Megatron-Core required') - - self.config = config - self.ddp_config = ddp_config or DistributedDataParallelConfig( - overlap_grad_reduce=True, - use_distributed_optimizer=False, - ) - - # Patch LoRA layers for multi-tenant - self._multi_adapter = MegatronMultiAdapter() - self.model = self._multi_adapter(model) - - # Create DDP - self._ddp = MultiTenantLoRADDP( - config=self.config, - ddp_config=self.ddp_config, - module=self.model, - ) - - # Create tenant manager - self._manager = TenantManager( - model=self.model, - default_process_group=mpu.get_data_parallel_group( - with_context_parallel=True), - ) - - # Wire up callbacks - self._manager.register_add_callback(self._on_tenant_added) - self._manager.register_remove_callback(self._on_tenant_removed) - - logger.info('MultiTenantMegatronModel initialized') - - def _on_tenant_added(self, state: TenantState): - """Called when tenant is added via manager.""" - self._ddp.add_tenant( - tenant_id=state.tenant_id, - params=state.params, - process_group=state.process_group, - param_names=state.param_names, - ) - - def _on_tenant_removed(self, state: TenantState): - """Called when tenant is removed via manager.""" - if self._ddp.has_tenant(state.tenant_id): - self._ddp.remove_tenant(state.tenant_id) - - def forward(self, *args, **kwargs): - """Forward pass.""" - return self._ddp(*args, **kwargs) - - # ========== Tenant Lifecycle ========== - - def initialize(self, **kwargs) -> str: - """ - Initialize a tenant. - - Args: - **kwargs: Passed to TenantManager.initialize() - - Returns: - Tenant ID. 
- """ - return self._manager.initialize(**kwargs) - - def finalize(self, tenant_id: Optional[str] = None): - """Finalize a tenant.""" - self._manager.finalize(tenant_id) - - @contextmanager - def scope(self, tenant_id: Optional[str] = None): - """Context manager for tenant scope.""" - with self._manager.scope(tenant_id) as state: - # Also set adapter - MegatronMultiAdapter.set_current_adapter_name(state.adapter_name) - try: - yield state - finally: - MegatronMultiAdapter.set_current_adapter_name(None) - - # ========== Training Operations ========== - - def zero_grad(self, tenant_id: Optional[str] = None): - """Zero gradients.""" - tenant_id = tenant_id or require_tenant() - state = self._manager.get(tenant_id) - - self._ddp.zero_grad_buffer(tenant_id) - if state.optimizer: - state.optimizer.zero_grad(set_to_none=True) - - def backward(self, loss: torch.Tensor, tenant_id: Optional[str] = None): - """Backward pass.""" - tenant_id = tenant_id or require_tenant() - state = self._manager.get(tenant_id) - - MegatronMultiAdapter.set_current_adapter_name(state.adapter_name) - scaled_loss = loss / state.gradient_accumulation_steps - scaled_loss.backward() - - @contextmanager - def no_sync(self, tenant_id: Optional[str] = None): - """Disable gradient sync.""" - with self._ddp.no_sync(tenant_id): - yield - - def finish_grad_sync(self, tenant_id: Optional[str] = None): - """Finish gradient sync.""" - self._ddp.finish_grad_sync(tenant_id) - - def clip_grad_norm( - self, - max_norm: Optional[float] = None, - tenant_id: Optional[str] = None, - ) -> torch.Tensor: - """Clip gradients.""" - tenant_id = tenant_id or require_tenant() - state = self._manager.get(tenant_id) - max_norm = max_norm or state.max_grad_norm - return torch.nn.utils.clip_grad_norm_(state.params, max_norm) - - def step(self, tenant_id: Optional[str] = None): - """Optimizer step.""" - tenant_id = tenant_id or require_tenant() - state = self._manager.get(tenant_id) - if state.optimizer: - state.optimizer.step() - 
- def lr_step(self, tenant_id: Optional[str] = None): - """LR scheduler step.""" - tenant_id = tenant_id or require_tenant() - state = self._manager.get(tenant_id) - if state.scheduler: - state.scheduler.step() - - def get_lr(self, tenant_id: Optional[str] = None) -> Optional[float]: - """Get current LR.""" - tenant_id = tenant_id or require_tenant() - state = self._manager.get(tenant_id) - if state.optimizer: - return state.optimizer.param_groups[0]['lr'] - return None - - # ========== Utilities ========== - - def tenant_count(self) -> int: - """Get number of active tenants.""" - return self._manager.count() - - def has_tenant(self, tenant_id: str) -> bool: - """Check if a specific tenant exists.""" - return self._manager.has(tenant_id) - - @property - def ddp(self) -> MultiTenantLoRADDP: - """Get DDP wrapper.""" - return self._ddp - - @property - def manager(self) -> TenantManager: - """Get tenant manager.""" - return self._manager - - @property - def unwrapped_model(self) -> nn.Module: - """Get unwrapped model.""" - return self.model diff --git a/tests/megatron/test_multi_tenant_benchmark.py b/tests/megatron/test_multi_tenant_benchmark.py deleted file mode 100644 index b9dcb3f8..00000000 --- a/tests/megatron/test_multi_tenant_benchmark.py +++ /dev/null @@ -1,746 +0,0 @@ -#!/usr/bin/env python -""" -Benchmark comparison of multi-tenant architectures. - -Compares: -1. Twinkle Mode: Per-tenant serial execution (independent calls) -2. Clock Cycle Mode: Unified scheduling + batched communication - -Key insight: For LLM+LoRA, batch merging is NOT possible because LoRA -weights are embedded in every layer. The benefit of Clock Cycle is -batched communication, not merged computation. 
- -Metrics: -- Throughput (samples/second) -- Latency (per step) -- Communication efficiency (N syncs vs 1 sync) -""" - -import argparse -import logging -import threading -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple - -import torch -import torch.nn as nn - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - - -# ============ Mock Model with Tinker-compatible API ============ - -class MockBaseModel(nn.Module): - """Mock base model (shared across tenants).""" - - def __init__(self, hidden_size: int, num_layers: int, simulate_ms: float): - super().__init__() - self.hidden_size = hidden_size - self.num_layers = num_layers - self.simulate_ms = simulate_ms - - # Create base layers (frozen) - self.layers = nn.ModuleList([ - nn.Linear(hidden_size, hidden_size, bias=False) - for _ in range(num_layers) - ]) - - for layer in self.layers: - layer.weight.requires_grad = False - - # Stats - self.forward_calls = 0 - self.total_samples = 0 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass - simulates compute time.""" - batch_size = x.size(0) - self.forward_calls += 1 - self.total_samples += batch_size - - # Simulate compute (scales slightly with batch size) - # Key insight: one large batch is more efficient than N small batches - time.sleep(self.simulate_ms / 1000.0 * (1 + 0.1 * (batch_size / 8))) - - for layer in self.layers: - x = layer(x) - x = torch.relu(x) - return x - - def reset_stats(self): - self.forward_calls = 0 - self.total_samples = 0 - - -class MockLoRAAdapter(nn.Module): - """Mock LoRA adapter for a single tenant.""" - - def __init__(self, hidden_size: int, rank: int = 8, simulate_ms: float = 1.0): - super().__init__() - self.hidden_size = hidden_size - self.rank = rank - self.simulate_ms = 
simulate_ms - - self.lora_A = nn.Parameter(torch.randn(rank, hidden_size) * 0.01) - self.lora_B = nn.Parameter(torch.zeros(hidden_size, rank)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply LoRA transformation.""" - time.sleep(self.simulate_ms / 1000.0) - return x + x @ self.lora_A.T @ self.lora_B.T - - -class MockMultiTenantModel(nn.Module): - """ - Mock multi-tenant model with Tinker-compatible API. - - Supports: - - base_forward(): Run base model only (for batch merging) - - apply_lora(): Apply per-tenant LoRA - - scope(): Context manager for tenant selection - - finish_grad_sync_batched(): Batched gradient sync - """ - - def __init__( - self, - hidden_size: int = 256, - num_layers: int = 4, - lora_rank: int = 8, - base_model_ms: float = 10.0, - lora_ms: float = 2.0, - comm_ms: float = 5.0, - ): - super().__init__() - self.hidden_size = hidden_size - self.num_layers = num_layers - self.lora_rank = lora_rank - self.base_model_ms = base_model_ms - self.lora_ms = lora_ms - self.comm_ms = comm_ms - - # Base model (shared) - self.base_model = MockBaseModel(hidden_size, num_layers, base_model_ms) - - # Per-tenant adapters - self._adapters: Dict[str, MockLoRAAdapter] = nn.ModuleDict() - self._optimizers: Dict[str, torch.optim.Optimizer] = {} - - # Current tenant context - self._current_tenant: Optional[str] = None - self._lock = threading.Lock() - - # Stats - self._compute_time = 0.0 - self._comm_time = 0.0 - - def initialize( - self, - tenant_id: Optional[str] = None, - optimizer_kwargs: Optional[Dict] = None, - **kwargs, - ) -> str: - """Initialize a tenant.""" - import uuid - tenant_id = tenant_id or str(uuid.uuid4())[:8] - - with self._lock: - if tenant_id in self._adapters: - raise ValueError(f"Tenant {tenant_id} exists") - - # Create adapter - adapter = MockLoRAAdapter(self.hidden_size, self.lora_rank, self.lora_ms) - self._adapters[tenant_id] = adapter - - # Create optimizer - opt_kwargs = optimizer_kwargs or {'lr': 1e-4} - 
self._optimizers[tenant_id] = torch.optim.AdamW(adapter.parameters(), **opt_kwargs) - - self._current_tenant = tenant_id - - return tenant_id - - def finalize(self, tenant_id: Optional[str] = None): - """Finalize a tenant.""" - tenant_id = tenant_id or self._current_tenant - with self._lock: - if tenant_id in self._adapters: - del self._adapters[tenant_id] - del self._optimizers[tenant_id] - - @contextmanager - def scope(self, tenant_id: str): - """Context manager for tenant scope.""" - old = self._current_tenant - self._current_tenant = tenant_id - try: - yield - finally: - self._current_tenant = old - - def base_forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Run ONLY the base model (for batch merging). - - This is the key for Tinker efficiency - call once for all tenants. - """ - return self.base_model(x) - - def apply_lora(self, features: torch.Tensor, tenant_id: Optional[str] = None) -> torch.Tensor: - """Apply per-tenant LoRA to pre-computed features.""" - tenant_id = tenant_id or self._current_tenant - if tenant_id and tenant_id in self._adapters: - return self._adapters[tenant_id](features) - return features - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Full forward pass (base + current tenant's LoRA).""" - features = self.base_model(x) - return self.apply_lora(features) - - def backward(self, loss: torch.Tensor, tenant_id: Optional[str] = None): - """Backward pass.""" - t0 = time.time() - loss.backward() - self._compute_time += time.time() - t0 - - def finish_grad_sync(self, tenant_id: Optional[str] = None): - """Finish gradient sync for single tenant.""" - time.sleep(self.comm_ms / 1000.0) - self._comm_time += self.comm_ms / 1000.0 - - def finish_grad_sync_batched(self, tenant_ids: List[str]): - """ - Batched gradient sync (Tinker optimization). - - One all-reduce for all tenants instead of N all-reduces. 
- """ - # Simulate batched communication (more efficient than N separate calls) - # Overhead is sub-linear with number of tenants - batched_time = self.comm_ms / 1000.0 * (1 + 0.1 * len(tenant_ids)) - time.sleep(batched_time) - self._comm_time += batched_time - - def clip_grad_norm(self, tenant_id: Optional[str] = None, max_norm: float = 1.0): - """Clip gradients.""" - tenant_id = tenant_id or self._current_tenant - if tenant_id in self._adapters: - torch.nn.utils.clip_grad_norm_( - self._adapters[tenant_id].parameters(), max_norm - ) - - def step(self, tenant_id: Optional[str] = None): - """Optimizer step.""" - tenant_id = tenant_id or self._current_tenant - if tenant_id in self._optimizers: - self._optimizers[tenant_id].step() - - def zero_grad(self, tenant_id: Optional[str] = None): - """Zero gradients.""" - tenant_id = tenant_id or self._current_tenant - if tenant_id in self._optimizers: - self._optimizers[tenant_id].zero_grad(set_to_none=True) - - def get_stats(self) -> Dict[str, Any]: - return { - 'compute_time': self._compute_time, - 'comm_time': self._comm_time, - 'base_model_forward_calls': self.base_model.forward_calls, - 'base_model_total_samples': self.base_model.total_samples, - } - - def reset_stats(self): - self._compute_time = 0.0 - self._comm_time = 0.0 - self.base_model.reset_stats() - - def tenant_count(self) -> int: - return len(self._adapters) - - def has_tenant(self, tenant_id: str) -> bool: - return tenant_id in self._adapters - - -# ============ Benchmark Classes ============ - -@dataclass -class BenchmarkConfig: - """Configuration for benchmark.""" - num_tenants: int = 4 - steps_per_tenant: int = 10 - batch_size_per_tenant: int = 8 - hidden_size: int = 256 - base_model_ms: float = 10.0 # Base model forward time - lora_ms: float = 2.0 # LoRA forward time per tenant - comm_ms: float = 5.0 # Communication time - clock_cycle_interval_ms: float = 50.0 - - -@dataclass -class BenchmarkResult: - """Result of a benchmark run.""" - mode: str - 
total_time: float - total_steps: int - total_samples: int - throughput_steps: float # steps/second - throughput_samples: float # samples/second - avg_latency: float # seconds per step - base_model_calls: int # Number of base model forward calls - base_model_samples: int # Total samples processed by base model - compute_time: float - comm_time: float - gpu_utilization: float # compute_time / total_time - - def __str__(self): - return ( - f"{self.mode}:\n" - f" Total time: {self.total_time:.2f}s\n" - f" Total steps: {self.total_steps} ({self.total_samples} samples)\n" - f" Throughput: {self.throughput_steps:.2f} steps/s, {self.throughput_samples:.2f} samples/s\n" - f" Avg latency: {self.avg_latency*1000:.2f} ms/step\n" - f" Base model calls: {self.base_model_calls} (samples: {self.base_model_samples})\n" - f" GPU utilization: {self.gpu_utilization*100:.1f}%\n" - ) - - -class TwinkleBenchmark: - """ - Benchmark for Twinkle mode (per-tenant serial execution). - - In this mode: - - Each tenant's request is processed separately - - Base model is called N times (once per tenant) - - Gradient sync is done N times - """ - - def __init__(self, config: BenchmarkConfig): - self.config = config - self.model = MockMultiTenantModel( - hidden_size=config.hidden_size, - base_model_ms=config.base_model_ms, - lora_ms=config.lora_ms, - comm_ms=config.comm_ms, - ) - - def run(self) -> BenchmarkResult: - """Run the benchmark.""" - logger.info("Running Twinkle mode benchmark...") - - # Initialize tenants - tenant_ids = [] - for i in range(self.config.num_tenants): - tid = self.model.initialize(tenant_id=f"tenant_{i}") - tenant_ids.append(tid) - - self.model.reset_stats() - - # Create dummy input - x = torch.randn(self.config.batch_size_per_tenant, self.config.hidden_size) - - total_steps = 0 - total_samples = 0 - step_latencies = [] - - start_time = time.time() - - # Training loop - serial per tenant - for step in range(self.config.steps_per_tenant): - for tenant_id in tenant_ids: - 
step_start = time.time() - - with self.model.scope(tenant_id): - self.model.zero_grad(tenant_id) - output = self.model(x) # Full forward (base + LoRA) - loss = output.mean() - self.model.backward(loss, tenant_id) - self.model.finish_grad_sync(tenant_id) # Individual sync - self.model.clip_grad_norm(tenant_id) - self.model.step(tenant_id) - - step_latencies.append(time.time() - step_start) - total_steps += 1 - total_samples += self.config.batch_size_per_tenant - - total_time = time.time() - start_time - - # Cleanup - for tid in tenant_ids: - self.model.finalize(tid) - - # Calculate metrics - stats = self.model.get_stats() - - return BenchmarkResult( - mode="Twinkle (Serial)", - total_time=total_time, - total_steps=total_steps, - total_samples=total_samples, - throughput_steps=total_steps / total_time, - throughput_samples=total_samples / total_time, - avg_latency=sum(step_latencies) / len(step_latencies), - base_model_calls=stats['base_model_forward_calls'], - base_model_samples=stats['base_model_total_samples'], - compute_time=stats['compute_time'], - comm_time=stats['comm_time'], - gpu_utilization=stats['compute_time'] / total_time, - ) - - -class TinkerBenchmark: - """ - Benchmark for Tinker mode (clock cycle with batch merging). 
- - In this mode: - - Multiple tenants' requests are batched in each cycle - - Base model is called ONCE per cycle (with merged batch) - - Gradient sync is done ONCE per cycle (batched) - """ - - def __init__(self, config: BenchmarkConfig): - self.config = config - self.model = MockMultiTenantModel( - hidden_size=config.hidden_size, - base_model_ms=config.base_model_ms, - lora_ms=config.lora_ms, - comm_ms=config.comm_ms, - ) - - def run(self) -> BenchmarkResult: - """Run the benchmark.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, - ClockCycleTrainingClient, - ) - - logger.info("Running Tinker mode benchmark...") - - # Initialize tenants - tenant_ids = [] - for i in range(self.config.num_tenants): - tid = self.model.initialize(tenant_id=f"tenant_{i}") - tenant_ids.append(tid) - - self.model.reset_stats() - - # Create scheduler - scheduler = ClockCycleScheduler( - model=self.model, - cycle_interval_ms=self.config.clock_cycle_interval_ms, - ) - scheduler.start() - - # Create clients for each tenant - clients = { - tid: ClockCycleTrainingClient(scheduler, tid) - for tid in tenant_ids - } - - total_steps = 0 - total_samples = 0 - step_latencies = [] - - start_time = time.time() - - # Training loop - all tenants submit concurrently - def tenant_worker(tenant_id: str, client: ClockCycleTrainingClient): - nonlocal total_steps, total_samples - latencies = [] - - # Each tenant has its own batch - x = torch.randn(self.config.batch_size_per_tenant, self.config.hidden_size) - - for step in range(self.config.steps_per_tenant): - step_start = time.time() - - # Submit forward-backward and optimizer step - result = client.train_step(x) - - latencies.append(time.time() - step_start) - total_steps += 1 - total_samples += self.config.batch_size_per_tenant - - return latencies - - # Run all tenants concurrently - with ThreadPoolExecutor(max_workers=self.config.num_tenants) as executor: - futures = { - executor.submit(tenant_worker, tid, 
clients[tid]): tid - for tid in tenant_ids - } - - for future in as_completed(futures): - try: - latencies = future.result() - step_latencies.extend(latencies) - except Exception as e: - logger.error(f"Tenant worker failed: {e}") - - total_time = time.time() - start_time - - # Stop scheduler - scheduler.stop() - - # Get scheduler stats - sched_stats = scheduler.get_summary_stats() - - # Cleanup - for tid in tenant_ids: - self.model.finalize(tid) - - # Calculate metrics - model_stats = self.model.get_stats() - - return BenchmarkResult( - mode="Tinker (Clock Cycle)", - total_time=total_time, - total_steps=total_steps, - total_samples=total_samples, - throughput_steps=total_steps / total_time, - throughput_samples=total_samples / total_time, - avg_latency=sum(step_latencies) / len(step_latencies) if step_latencies else 0, - base_model_calls=model_stats['base_model_forward_calls'], - base_model_samples=model_stats['base_model_total_samples'], - compute_time=model_stats['compute_time'], - comm_time=model_stats['comm_time'], - gpu_utilization=sched_stats.get('gpu_utilization', 0), - ) - - -# ============ Test Functions ============ - -def test_twinkle_mode(): - """Test Twinkle mode functionality.""" - logger.info("Testing Twinkle mode...") - - model = MockMultiTenantModel(base_model_ms=1.0, lora_ms=0.5, comm_ms=0.5) - - # Initialize 2 tenants - tid1 = model.initialize(tenant_id="test_1") - tid2 = model.initialize(tenant_id="test_2") - - assert model.tenant_count() == 2 - assert model.has_tenant(tid1) - assert model.has_tenant(tid2) - - # Training step for tenant 1 - x = torch.randn(4, 256) - with model.scope(tid1): - model.zero_grad(tid1) - output = model(x) - loss = output.mean() - model.backward(loss, tid1) - model.finish_grad_sync(tid1) - model.step(tid1) - - # Training step for tenant 2 - with model.scope(tid2): - model.zero_grad(tid2) - output = model(x) - loss = output.mean() - model.backward(loss, tid2) - model.finish_grad_sync(tid2) - model.step(tid2) - - # 
Verify base model was called twice - stats = model.get_stats() - assert stats['base_model_forward_calls'] == 2, f"Expected 2 calls, got {stats['base_model_forward_calls']}" - - # Cleanup - model.finalize(tid1) - model.finalize(tid2) - - assert model.tenant_count() == 0 - - logger.info("Twinkle mode test PASSED") - return True - - -def test_tinker_mode(): - """Test Tinker mode functionality.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, - ClockCycleTrainingClient, - ) - - logger.info("Testing Tinker mode...") - - model = MockMultiTenantModel(base_model_ms=1.0, lora_ms=0.5, comm_ms=0.5) - - # Initialize 2 tenants - tid1 = model.initialize(tenant_id="test_1") - tid2 = model.initialize(tenant_id="test_2") - - # Create scheduler - scheduler = ClockCycleScheduler(model, cycle_interval_ms=50.0) - scheduler.start() - - # Create clients - client1 = ClockCycleTrainingClient(scheduler, tid1) - client2 = ClockCycleTrainingClient(scheduler, tid2) - - x1 = torch.randn(4, 256) - x2 = torch.randn(4, 256) - - # Both tenants submit requests (should be in same cycle) - future1 = client1.forward_backward(x1) - future2 = client2.forward_backward(x2) - - opt1 = client1.optim_step() - opt2 = client2.optim_step() - - # Wait for results - result1 = future1.result(timeout=5.0) - result2 = future2.result(timeout=5.0) - opt1.result(timeout=5.0) - opt2.result(timeout=5.0) - - assert 'loss' in result1 or 'error' in result1, f"Unexpected result: {result1}" - assert 'loss' in result2 or 'error' in result2, f"Unexpected result: {result2}" - - # Check they were in same cycle - if 'cycle_id' in result1 and 'cycle_id' in result2: - logger.info(f"Cycle IDs: {result1['cycle_id']}, {result2['cycle_id']}") - - # Stop scheduler - scheduler.stop() - - # Check stats - stats = scheduler.get_summary_stats() - logger.info(f"Scheduler stats: {stats}") - - # Cleanup - model.finalize(tid1) - model.finalize(tid2) - - logger.info("Tinker mode test PASSED") - return 
True - - -def test_batch_merging(): - """Test that batch merging works correctly.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import BatchBuilder, TrainingRequest, RequestType - - logger.info("Testing batch merging...") - - builder = BatchBuilder() - - # Create requests from 3 tenants - requests = { - 'tenant_a': TrainingRequest( - tenant_id='tenant_a', - request_type=RequestType.FORWARD_BACKWARD, - inputs=torch.randn(4, 256), - ), - 'tenant_b': TrainingRequest( - tenant_id='tenant_b', - request_type=RequestType.FORWARD_BACKWARD, - inputs=torch.randn(8, 256), - ), - 'tenant_c': TrainingRequest( - tenant_id='tenant_c', - request_type=RequestType.FORWARD_BACKWARD, - inputs=torch.randn(2, 256), - ), - } - - # Build merged batch - merged = builder.build(requests) - - # Verify - assert merged.total_size == 14, f"Expected 14, got {merged.total_size}" - assert merged.merged_inputs.shape == (14, 256), f"Wrong shape: {merged.merged_inputs.shape}" - assert merged.tenant_slices['tenant_a'] == (0, 4) - assert merged.tenant_slices['tenant_b'] == (4, 12) - assert merged.tenant_slices['tenant_c'] == (12, 14) - - logger.info("Batch merging test PASSED") - return True - - -def run_benchmark_comparison(config: BenchmarkConfig): - """Run and compare both benchmarks.""" - print("") - print("=" * 60) - print("Multi-Tenant Architecture Benchmark") - print("=" * 60) - print(f"Config: {config}") - print("") - - # Run Twinkle benchmark - twinkle = TwinkleBenchmark(config) - twinkle_result = twinkle.run() - - # Run Tinker benchmark - tinker = TinkerBenchmark(config) - tinker_result = tinker.run() - - # Print results - print("") - print("=" * 60) - print("Results") - print("=" * 60) - print(twinkle_result) - print(tinker_result) - - # Comparison - print("=" * 60) - print("Comparison") - print("=" * 60) - - # Throughput - speedup = tinker_result.throughput_samples / twinkle_result.throughput_samples - print(f"Throughput speedup (Tinker/Twinkle): {speedup:.2f}x") - - # Base 
model efficiency - base_model_ratio = twinkle_result.base_model_calls / max(tinker_result.base_model_calls, 1) - print(f"Base model calls: {twinkle_result.base_model_calls} vs {tinker_result.base_model_calls} ({base_model_ratio:.1f}x fewer)") - - # Latency - latency_diff = (twinkle_result.avg_latency - tinker_result.avg_latency) / twinkle_result.avg_latency * 100 - print(f"Latency improvement: {latency_diff:.1f}%") - - # GPU utilization - gpu_diff = (tinker_result.gpu_utilization - twinkle_result.gpu_utilization) * 100 - print(f"GPU utilization difference: {gpu_diff:+.1f}%") - - return twinkle_result, tinker_result - - -def main(): - parser = argparse.ArgumentParser(description="Multi-tenant benchmark") - parser.add_argument("--num-tenants", type=int, default=4) - parser.add_argument("--steps", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--base-model-ms", type=float, default=10.0) - parser.add_argument("--lora-ms", type=float, default=2.0) - parser.add_argument("--comm-ms", type=float, default=5.0) - parser.add_argument("--cycle-ms", type=float, default=50.0) - parser.add_argument("--test-only", action="store_true", help="Run tests only") - args = parser.parse_args() - - if args.test_only: - test_twinkle_mode() - test_batch_merging() - test_tinker_mode() - logger.info("All tests passed!") - return - - config = BenchmarkConfig( - num_tenants=args.num_tenants, - steps_per_tenant=args.steps, - batch_size_per_tenant=args.batch_size, - base_model_ms=args.base_model_ms, - lora_ms=args.lora_ms, - comm_ms=args.comm_ms, - clock_cycle_interval_ms=args.cycle_ms, - ) - - run_benchmark_comparison(config) - - -if __name__ == "__main__": - main() diff --git a/tests/megatron/test_multi_tenant_modules.py b/tests/megatron/test_multi_tenant_modules.py deleted file mode 100644 index b2c3b09e..00000000 --- a/tests/megatron/test_multi_tenant_modules.py +++ /dev/null @@ -1,641 +0,0 @@ -#!/usr/bin/env python -""" -Unit tests for 
multi-tenant modules. - -Tests: -1. tenant_context.py - ContextVar-based tenant management -2. tenant_manager.py - Tenant lifecycle management (WIP, requires PEFT) -3. multi_tenant_ddp.py - Multi-tenant DDP (requires Megatron) -4. clock_cycle_scheduler.py - Clock cycle scheduling - -Each module is tested independently with mock dependencies where needed. -""" - -import logging -import threading -import time -import unittest -from concurrent.futures import ThreadPoolExecutor, as_completed -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, patch - -import torch -import torch.nn as nn - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - - -# ============================================================================ -# Test 1: tenant_context.py -# ============================================================================ - -class TestTenantContext(unittest.TestCase): - """Tests for tenant_context module.""" - - def setUp(self): - """Reset tenant context before each test.""" - from twinkle.megatron.distributed.tenant_context import set_current_tenant - set_current_tenant(None) - - def test_get_set_tenant(self): - """Test basic get/set tenant.""" - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, set_current_tenant - ) - - # Initially None - self.assertIsNone(get_current_tenant()) - - # Set tenant - set_current_tenant('tenant_a') - self.assertEqual(get_current_tenant(), 'tenant_a') - - # Change tenant - set_current_tenant('tenant_b') - self.assertEqual(get_current_tenant(), 'tenant_b') - - # Clear tenant - set_current_tenant(None) - self.assertIsNone(get_current_tenant()) - - def test_require_tenant(self): - """Test require_tenant raises error when not set.""" - from twinkle.megatron.distributed.tenant_context import ( - require_tenant, 
set_current_tenant - ) - - # Should raise when not set - with self.assertRaises(RuntimeError): - require_tenant() - - # Should return when set - set_current_tenant('tenant_a') - self.assertEqual(require_tenant(), 'tenant_a') - - def test_tenant_scope(self): - """Test tenant_scope context manager.""" - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, tenant_scope, set_current_tenant - ) - - set_current_tenant('outer') - - with tenant_scope('inner'): - self.assertEqual(get_current_tenant(), 'inner') - - # Should restore after context - self.assertEqual(get_current_tenant(), 'outer') - - def test_nested_scopes(self): - """Test nested tenant scopes.""" - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, tenant_scope - ) - - with tenant_scope('a'): - self.assertEqual(get_current_tenant(), 'a') - - with tenant_scope('b'): - self.assertEqual(get_current_tenant(), 'b') - - with tenant_scope('c'): - self.assertEqual(get_current_tenant(), 'c') - - self.assertEqual(get_current_tenant(), 'b') - - self.assertEqual(get_current_tenant(), 'a') - - def test_thread_isolation(self): - """Test that tenant context is isolated between threads.""" - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, set_current_tenant - ) - - results = {} - - def thread_func(tenant_id: str, delay: float): - set_current_tenant(tenant_id) - time.sleep(delay) - results[tenant_id] = get_current_tenant() - - # Run multiple threads - threads = [ - threading.Thread(target=thread_func, args=('thread_a', 0.1)), - threading.Thread(target=thread_func, args=('thread_b', 0.05)), - threading.Thread(target=thread_func, args=('thread_c', 0.15)), - ] - - for t in threads: - t.start() - for t in threads: - t.join() - - # Each thread should have its own context - self.assertEqual(results['thread_a'], 'thread_a') - self.assertEqual(results['thread_b'], 'thread_b') - self.assertEqual(results['thread_c'], 'thread_c') - - def 
test_generate_tenant_id(self): - """Test tenant ID generation.""" - from twinkle.megatron.distributed.tenant_context import generate_tenant_id - - id1 = generate_tenant_id() - id2 = generate_tenant_id() - - # Should be unique - self.assertNotEqual(id1, id2) - - # Should be 8 chars - self.assertEqual(len(id1), 8) - self.assertEqual(len(id2), 8) - - def test_with_tenant_context_decorator(self): - """Test @with_tenant_context decorator.""" - from twinkle.megatron.distributed.tenant_context import ( - with_tenant_context, tenant_scope - ) - - @with_tenant_context - def example_func(tenant_id: Optional[str] = None): - return tenant_id - - # Should use context when tenant_id not provided - with tenant_scope('context_tenant'): - result = example_func() - self.assertEqual(result, 'context_tenant') - - # Should use explicit tenant_id when provided - with tenant_scope('context_tenant'): - result = example_func(tenant_id='explicit_tenant') - self.assertEqual(result, 'explicit_tenant') - - -# ============================================================================ -# Test 2: clock_cycle_scheduler.py -# ============================================================================ - -class MockMultiTenantModel(nn.Module): - """Mock model that implements the required interface for ClockCycleScheduler.""" - - def __init__(self, hidden_size: int = 64, simulate_ms: float = 1.0): - super().__init__() - self.hidden_size = hidden_size - self.simulate_ms = simulate_ms - - # Base layer (frozen) - self.base = nn.Linear(hidden_size, hidden_size) - self.base.weight.requires_grad = False - - # Per-tenant adapters - self._adapters: Dict[str, nn.Module] = {} - self._optimizers: Dict[str, torch.optim.Optimizer] = {} - self._current_tenant: Optional[str] = None - - def add_tenant(self, tenant_id: str) -> None: - """Add a tenant with LoRA adapter.""" - adapter = nn.Linear(self.hidden_size, self.hidden_size) - self._adapters[tenant_id] = adapter - self._optimizers[tenant_id] = 
torch.optim.SGD(adapter.parameters(), lr=0.01) - - def remove_tenant(self, tenant_id: str) -> None: - """Remove a tenant.""" - if tenant_id in self._adapters: - del self._adapters[tenant_id] - del self._optimizers[tenant_id] - - @contextmanager - def scope(self, tenant_id: str): - """Context manager for tenant scope.""" - old = self._current_tenant - self._current_tenant = tenant_id - try: - yield - finally: - self._current_tenant = old - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass using current tenant's adapter.""" - time.sleep(self.simulate_ms / 1000.0) - - out = self.base(x) - if self._current_tenant and self._current_tenant in self._adapters: - out = out + self._adapters[self._current_tenant](x) - return out - - def zero_grad(self, tenant_id: str) -> None: - """Zero gradients for tenant.""" - if tenant_id in self._optimizers: - self._optimizers[tenant_id].zero_grad(set_to_none=True) - - def step(self, tenant_id: str) -> None: - """Optimizer step for tenant.""" - if tenant_id in self._optimizers: - self._optimizers[tenant_id].step() - - def clip_grad_norm(self, tenant_id: str, max_norm: float = 1.0) -> None: - """Clip gradients for tenant.""" - if tenant_id in self._adapters: - torch.nn.utils.clip_grad_norm_( - self._adapters[tenant_id].parameters(), max_norm - ) - - def finish_grad_sync(self, tenant_id: str) -> None: - """Gradient sync for single tenant (no-op in non-distributed).""" - pass - - def finish_grad_sync_batched(self, tenant_ids: List[str]) -> None: - """Batched gradient sync (no-op in non-distributed).""" - pass - - -class TestClockCycleScheduler(unittest.TestCase): - """Tests for clock_cycle_scheduler module.""" - - def test_model_interface_validation(self): - """Test that scheduler validates model interface.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ModelInterfaceError - ) - - # Model without required methods should fail - bad_model = nn.Linear(10, 10) - with 
self.assertRaises(ModelInterfaceError): - ClockCycleScheduler(bad_model) - - # Good model should work - good_model = MockMultiTenantModel() - scheduler = ClockCycleScheduler(good_model, cycle_interval_ms=10) - self.assertIsNotNone(scheduler) - - def test_basic_training_step(self): - """Test basic training step through scheduler.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ClockCycleTrainingClient - ) - - model = MockMultiTenantModel(simulate_ms=0.5) - model.add_tenant('tenant_a') - - scheduler = ClockCycleScheduler(model, cycle_interval_ms=20) - scheduler.start() - - try: - client = ClockCycleTrainingClient(scheduler, 'tenant_a') - - x = torch.randn(4, 64) - result = client.train_step(x) - - self.assertIn('loss', result) - self.assertIn('cycle_id', result) - self.assertEqual(result['batch_size'], 4) - - finally: - scheduler.stop() - model.remove_tenant('tenant_a') - - def test_multi_tenant_concurrent(self): - """Test multiple tenants submitting concurrently.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ClockCycleTrainingClient - ) - - model = MockMultiTenantModel(simulate_ms=0.5) - model.add_tenant('tenant_a') - model.add_tenant('tenant_b') - model.add_tenant('tenant_c') - - scheduler = ClockCycleScheduler(model, cycle_interval_ms=50) - scheduler.start() - - try: - clients = { - tid: ClockCycleTrainingClient(scheduler, tid) - for tid in ['tenant_a', 'tenant_b', 'tenant_c'] - } - - # Submit from multiple threads - results = {} - - def worker(tenant_id: str, client: ClockCycleTrainingClient): - x = torch.randn(4, 64) - return client.train_step(x) - - with ThreadPoolExecutor(max_workers=3) as executor: - futures = { - executor.submit(worker, tid, clients[tid]): tid - for tid in clients - } - for future in as_completed(futures): - tid = futures[future] - results[tid] = future.result() - - # All should succeed - for tid, result in results.items(): - self.assertIn('loss', 
result) - self.assertIn('cycle_id', result) - - # Check stats - stats = scheduler.get_summary_stats() - self.assertGreater(stats['total_cycles'], 0) - self.assertEqual(stats['total_samples'], 12) # 3 tenants * 4 samples - - finally: - scheduler.stop() - for tid in ['tenant_a', 'tenant_b', 'tenant_c']: - model.remove_tenant(tid) - - def test_gradient_isolation(self): - """Test that gradients are isolated between tenants.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ClockCycleTrainingClient - ) - - model = MockMultiTenantModel(simulate_ms=0.5) - model.add_tenant('tenant_a') - model.add_tenant('tenant_b') - - # Get initial weights - weight_a_before = model._adapters['tenant_a'].weight.clone() - weight_b_before = model._adapters['tenant_b'].weight.clone() - - scheduler = ClockCycleScheduler(model, cycle_interval_ms=20) - scheduler.start() - - try: - # Only tenant_a trains - client_a = ClockCycleTrainingClient(scheduler, 'tenant_a') - x = torch.randn(4, 64) - client_a.train_step(x) - - # tenant_a weights should change - weight_a_after = model._adapters['tenant_a'].weight - self.assertFalse(torch.allclose(weight_a_before, weight_a_after)) - - # tenant_b weights should NOT change - weight_b_after = model._adapters['tenant_b'].weight - self.assertTrue(torch.allclose(weight_b_before, weight_b_after)) - - finally: - scheduler.stop() - model.remove_tenant('tenant_a') - model.remove_tenant('tenant_b') - - def test_error_handling(self): - """Test error handling for failed requests.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler - ) - - # Create a model that raises error on forward for unknown tenant - class FailingModel(MockMultiTenantModel): - def forward(self, x): - if self._current_tenant not in self._adapters: - raise KeyError(f"Tenant '{self._current_tenant}' not found") - return super().forward(x) - - model = FailingModel() - # Don't add any tenants - requests should fail - - 
scheduler = ClockCycleScheduler(model, cycle_interval_ms=20) - scheduler.start() - - try: - # Submit request for non-existent tenant - future = scheduler.submit_forward_backward('nonexistent', torch.randn(4, 64)) - - # Should raise exception - with self.assertRaises(Exception): - future.result(timeout=5.0) - - finally: - scheduler.stop() - - def test_cycle_stats(self): - """Test cycle statistics collection.""" - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ClockCycleTrainingClient - ) - - model = MockMultiTenantModel(simulate_ms=1.0) - model.add_tenant('tenant_a') - - scheduler = ClockCycleScheduler(model, cycle_interval_ms=50) - scheduler.start() - - try: - client = ClockCycleTrainingClient(scheduler, 'tenant_a') - - # Run multiple steps - for _ in range(3): - x = torch.randn(4, 64) - client.train_step(x) - - # Get stats - stats_list = scheduler.get_stats() - summary = scheduler.get_summary_stats() - - self.assertEqual(len(stats_list), 3) - self.assertEqual(summary['total_cycles'], 3) - self.assertEqual(summary['total_samples'], 12) - - # Check individual stats - for stat in stats_list: - self.assertGreater(stat.forward_time, 0) - self.assertGreater(stat.duration, 0) - - finally: - scheduler.stop() - model.remove_tenant('tenant_a') - - -# ============================================================================ -# Test 3: multi_tenant_ddp.py (Mock test - requires Megatron) -# ============================================================================ - -class TestMultiTenantDDP(unittest.TestCase): - """Tests for multi_tenant_ddp module (mocked).""" - - def test_tenant_ddp_state_dataclass(self): - """Test TenantDDPState dataclass.""" - from twinkle.megatron.distributed.multi_tenant_ddp import TenantDDPState - - state = TenantDDPState(tenant_id='test_tenant') - - self.assertEqual(state.tenant_id, 'test_tenant') - self.assertEqual(state.params, []) - self.assertEqual(state.buffers, []) - 
self.assertEqual(state.bucket_groups, []) - self.assertIsNone(state.process_group) - - @unittest.skipUnless( - False, # Skip by default - requires Megatron - "Requires Megatron-Core" - ) - def test_multi_tenant_lora_ddp_creation(self): - """Test MultiTenantLoRADDP creation (requires Megatron).""" - pass - - def test_requires_megatron(self): - """Test that MultiTenantLoRADDP requires Megatron.""" - from unittest.mock import MagicMock, patch - - with patch('twinkle.megatron.distributed.multi_tenant_ddp.MEGATRON_AVAILABLE', False): - from twinkle.megatron.distributed.multi_tenant_ddp import MultiTenantLoRADDP - - with self.assertRaises(ImportError): - MultiTenantLoRADDP( - config=MagicMock(), - ddp_config=MagicMock(), - module=nn.Linear(10, 10), - ) - - -# ============================================================================ -# Test 4: MegatronMultiAdapter -# ============================================================================ - -class TestMegatronMultiAdapter(unittest.TestCase): - """Tests for MegatronMultiAdapter.""" - - def test_adapter_context_var(self): - """Test adapter name ContextVar management.""" - from twinkle.megatron.model.multi_tenant_megatron import MegatronMultiAdapter - - # Reset state - MegatronMultiAdapter._patched = False - - # Test get/set - self.assertIsNone(MegatronMultiAdapter.get_current_adapter_name()) - MegatronMultiAdapter.set_current_adapter_name("adapter_a") - self.assertEqual(MegatronMultiAdapter.get_current_adapter_name(), "adapter_a") - MegatronMultiAdapter.set_current_adapter_name(None) - self.assertIsNone(MegatronMultiAdapter.get_current_adapter_name()) - - -# ============================================================================ -# Test 5: Integration test - tenant_context + clock_cycle_scheduler -# ============================================================================ - -class TestIntegration(unittest.TestCase): - """Integration tests combining multiple modules.""" - - def 
test_context_with_scheduler(self): - """Test that tenant_context works with scheduler.""" - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, tenant_scope, set_current_tenant - ) - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ClockCycleTrainingClient - ) - - model = MockMultiTenantModel(simulate_ms=0.5) - model.add_tenant('tenant_a') - model.add_tenant('tenant_b') - - scheduler = ClockCycleScheduler(model, cycle_interval_ms=30) - scheduler.start() - - try: - # Test that context propagates correctly - with tenant_scope('tenant_a'): - self.assertEqual(get_current_tenant(), 'tenant_a') - - client = ClockCycleTrainingClient(scheduler, 'tenant_a') - x = torch.randn(4, 64) - result = client.train_step(x) - - self.assertIn('loss', result) - - # Context should be cleared outside - set_current_tenant(None) - self.assertIsNone(get_current_tenant()) - - finally: - scheduler.stop() - model.remove_tenant('tenant_a') - model.remove_tenant('tenant_b') - - def test_multi_threaded_with_context(self): - """Test multi-threaded training with tenant context.""" - from twinkle.megatron.distributed.tenant_context import ( - get_current_tenant, tenant_scope - ) - from twinkle.megatron.distributed.clock_cycle_scheduler import ( - ClockCycleScheduler, ClockCycleTrainingClient - ) - - model = MockMultiTenantModel(simulate_ms=0.5) - for i in range(4): - model.add_tenant(f'tenant_{i}') - - scheduler = ClockCycleScheduler(model, cycle_interval_ms=50) - scheduler.start() - - results = {} - errors = [] - - def worker(tenant_id: str): - try: - with tenant_scope(tenant_id): - # Verify context is correct - if get_current_tenant() != tenant_id: - errors.append(f"Context mismatch for {tenant_id}") - return - - client = ClockCycleTrainingClient(scheduler, tenant_id) - x = torch.randn(4, 64) - result = client.train_step(x) - results[tenant_id] = result - except Exception as e: - errors.append(str(e)) - - try: - with 
ThreadPoolExecutor(max_workers=4) as executor: - futures = [ - executor.submit(worker, f'tenant_{i}') - for i in range(4) - ] - for f in futures: - f.result() - - self.assertEqual(len(errors), 0, f"Errors: {errors}") - self.assertEqual(len(results), 4) - - for tid, result in results.items(): - self.assertIn('loss', result) - - finally: - scheduler.stop() - for i in range(4): - model.remove_tenant(f'tenant_{i}') - - -# ============================================================================ -# Main -# ============================================================================ - -def run_tests(): - """Run all tests.""" - loader = unittest.TestLoader() - suite = unittest.TestSuite() - - # Add test cases - suite.addTests(loader.loadTestsFromTestCase(TestTenantContext)) - suite.addTests(loader.loadTestsFromTestCase(TestClockCycleScheduler)) - suite.addTests(loader.loadTestsFromTestCase(TestMultiTenantDDP)) - suite.addTests(loader.loadTestsFromTestCase(TestMegatronMultiAdapter)) - suite.addTests(loader.loadTestsFromTestCase(TestIntegration)) - - # Run - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(suite) - - return result.wasSuccessful() - - -if __name__ == '__main__': - success = run_tests() - exit(0 if success else 1) From 08d2020cb3cea6dca5f5e97e5ca58e8e43124534 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 15:55:20 +0800 Subject: [PATCH 12/22] Restore megatron.py to working version 70ff0ba The modified forward_backward() was causing hangs and GEMM errors. Reverted to the clean version that works correctly with TP/PP. 
--- src/twinkle/model/megatron.py | 1060 +++++++++++++-------------------- 1 file changed, 401 insertions(+), 659 deletions(-) diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index f9171e8a..281256b4 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -7,20 +7,19 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union import torch -import torch.distributed as dist import torch.nn as nn +import torch.distributed as dist from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler import twinkle -from twinkle import DeviceMesh, remote_class, remote_function, template +from twinkle import remote_class, remote_function, template, DeviceMesh from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.loss import Loss, MegatronCrossEntropyLoss from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils.plugin import Plugin - from .base import TwinkleModel from .strategy import MegatronStrategy @@ -30,8 +29,7 @@ from megatron.core.distributed import DistributedDataParallel as MegatronDDP from packaging import version MEGATRON_AVAILABLE = True - mcore_013 = version.parse( - megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') except ImportError: MEGATRON_AVAILABLE = False mcore_013 = False @@ -40,7 +38,7 @@ @dataclass class MegatronOptimizerGroup: """Optimizer group for Megatron training. - + Similar to OptimizerGroup but adapted for Megatron's distributed training. 
""" adapter_name: str = None @@ -56,14 +54,8 @@ class MegatronOptimizerGroup: gradient_accumulation_steps: int = 1 cur_step: int = 0 dp_group = None - # Megatron optimizer specific fields - is_megatron_optimizer: bool = False - _last_grad_norm: float = 0.0 - _last_step_success: bool = True - def do_grad_sync(self, - gradient_accumulation_steps: Optional[int] = None - ) -> bool: + def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: """Check if gradient synchronization should happen.""" if gradient_accumulation_steps is None: gradient_accumulation_steps = self.gradient_accumulation_steps @@ -77,21 +69,22 @@ def check_megatron_available(): """Check if Megatron-Core is available.""" if not MEGATRON_AVAILABLE: raise ImportError( - 'Megatron-Core is not installed. Please install it with: ' - 'pip install megatron-core') + "Megatron-Core is not installed. Please install it with: " + "pip install megatron-core" + ) @remote_class(execute='all') class MegatronModel(TwinkleModel, nn.Module): """Megatron-Core model wrapper for twinkle training framework. - + Note: Uses execute='all' to create workers on all ranks, which is required for Megatron's TP/DP parallelism where all ranks must participate in collective operations like gradient all-reduce. - + This class provides a similar API to TransformersModel but uses Megatron-Core as the training backend, supporting TP/PP/CP/EP parallelism. - + Args: pretrained_model_name_or_path: HuggingFace model path or ID. device_mesh: Twinkle DeviceMesh for distributed training. @@ -104,6 +97,7 @@ class MegatronModel(TwinkleModel, nn.Module): use_distributed_optimizer: Use Megatron's distributed optimizer. **kwargs: Additional arguments passed to model initialization. 
""" + def __init__( self, pretrained_model_name_or_path: str, @@ -116,30 +110,28 @@ def __init__( mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', use_distributed_optimizer: bool = True, load_weights: bool = True, - use_megatron_bridge: - bool = True, # Use bridge-based initialization (recommended) - recompute_granularity: Optional[ - str] = 'selective', # Activation checkpointing + use_megatron_bridge: bool = True, # Use bridge-based initialization (recommended) + recompute_granularity: Optional[str] = 'selective', # Activation checkpointing recompute_modules: Optional[list] = None, # Modules to recompute **kwargs, ): check_megatron_available() nn.Module.__init__(self) - + self.model_id = pretrained_model_name_or_path self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.use_megatron_bridge = use_megatron_bridge self.recompute_granularity = recompute_granularity self.recompute_modules = recompute_modules - + # Load HuggingFace config first model_path = HubOperation.download_model(pretrained_model_name_or_path) self._load_hf_config(model_path) - + # Store model_path for later use self._model_path = model_path - + # Create Megatron strategy self.strategy = MegatronStrategy( tensor_model_parallel_size=tensor_model_parallel_size, @@ -150,27 +142,25 @@ def __init__( use_distributed_optimizer=use_distributed_optimizer, mixed_precision=mixed_precision, ) - + # Initialize parallel state (skip if using bridge init, as it handles this) if not use_megatron_bridge: self.strategy.initialize() - + # Create Megatron model - self.model = self._create_megatron_model(model_path, load_weights, - **kwargs) - + self.model = self._create_megatron_model(model_path, load_weights, **kwargs) + self._model_wrapped = False # This correctly handles vocab sharding in Tensor Parallelism self.optimizer_group: Dict[str, MegatronOptimizerGroup] = { - _default_adapter_name: - MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) + _default_adapter_name: 
MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) } - + def _load_hf_config(self, model_path: str): """Load HuggingFace model config.""" from transformers import AutoConfig self.hf_config = AutoConfig.from_pretrained(model_path) - + def _create_megatron_model( self, model_path: str, @@ -178,12 +168,12 @@ def _create_megatron_model( **kwargs, ) -> nn.Module: """Create Megatron model from HuggingFace checkpoint. - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. **kwargs: Additional arguments. - + Returns: Megatron model on GPU. """ @@ -192,17 +182,15 @@ def _create_megatron_model( params_dtype = torch.float16 elif self.mixed_precision == 'no': params_dtype = torch.float32 - + if self.use_megatron_bridge: # Use bridge-based initialization (recommended) # This ensures all patches are applied and config is correctly generated - return self._create_megatron_model_with_bridge( - model_path, load_weights, params_dtype, **kwargs) + return self._create_megatron_model_with_bridge(model_path, load_weights, params_dtype, **kwargs) else: # Use twinkle's native initialization - return self._create_megatron_model_native(model_path, load_weights, - params_dtype, **kwargs) - + return self._create_megatron_model_native(model_path, load_weights, params_dtype, **kwargs) + def _create_megatron_model_with_bridge( self, model_path: str, @@ -211,25 +199,25 @@ def _create_megatron_model_with_bridge( **kwargs, ) -> nn.Module: """Create Megatron model using bridge-based initialization flow. - + This approach uses TwinkleBridgeInitializer for independent initialization It includes: - Proper config conversion from HuggingFace to Megatron format - Correct Megatron initialization (initialize_megatron) - Correct model creation - Weight loading with TwinkleGPTBridge - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. params_dtype: Parameter dtype. **kwargs: Additional arguments. 
- + Returns: Megatron model on GPU. """ from twinkle.megatron.model.bridge import TwinkleBridgeInitializer - + # Create bridge-based initializer self._bridge_initializer = TwinkleBridgeInitializer( tp_size=self.strategy.tp_size, @@ -239,30 +227,24 @@ def _create_megatron_model_with_bridge( params_dtype=params_dtype, use_cpu_initialization=False, attention_backend='flash', # Use flash for training performance - sequence_parallel=self.strategy.sequence_parallel, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules, recompute_method=getattr(self, 'recompute_method', None), recompute_num_layers=getattr(self, 'recompute_num_layers', None), ) - + # Create model (this calls initialize_megatron internally) - model = self._bridge_initializer.create_model( - model_path, load_weights=load_weights) - + model = self._bridge_initializer.create_model(model_path, load_weights=load_weights) + # Update strategy state since bridge has initialized Megatron self.strategy._initialized = True self.strategy._parallel_state = mpu - - # Save transformer config for DDP wrapping - self._transformer_config = getattr(self._bridge_initializer, - '_transformer_config', None) - + # Move to GPU model = self._move_model_to_gpu(model) - + return model - + def _create_megatron_model_native( self, model_path: str, @@ -271,20 +253,20 @@ def _create_megatron_model_native( **kwargs, ) -> nn.Module: """Create Megatron model using twinkle's native initialization. - + This is the fallback method when bridge is not available. - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. params_dtype: Parameter dtype. **kwargs: Additional arguments. - + Returns: Megatron model on GPU. 
""" from twinkle.megatron.model.initializer import MegatronModelInitializer - + initializer = MegatronModelInitializer( tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, @@ -293,44 +275,43 @@ def _create_megatron_model_native( sequence_parallel=self.strategy.sequence_parallel, params_dtype=params_dtype, ) - + # Create model model = initializer.create_gpt_model(self.hf_config, **kwargs) - + # Load weights if load_weights: initializer.load_from_hf(model, model_path, self.hf_config) - + model = self._move_model_to_gpu(model) - + return model - + def _move_model_to_gpu(self, model: nn.Module) -> nn.Module: """Move model to correct GPU device. - + This method handles moving parameters, buffers, and any cached tensors (like RoPE embeddings) to the correct device for distributed training. """ # Determine the target device based on local rank - local_rank = dist.get_rank() % torch.cuda.device_count( - ) if dist.is_initialized() else 0 + local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0 device = torch.device(f'cuda:{local_rank}') - + # Set CUDA device explicitly torch.cuda.set_device(local_rank) - + # Move all parameters and buffers to GPU model = model.to(device) - + # Force synchronize to ensure all transfers complete if torch.cuda.is_available(): torch.cuda.synchronize(device) - + return model - + def _lazy_wrap_model(self): """Lazily wrap model with distributed wrapper. - + Note: This should only be called after prepare_training() has been executed on all workers. Direct calls from forward() may cause deadlocks if not all DP ranks are participating. 
@@ -339,10 +320,9 @@ def _lazy_wrap_model(self): # Find an optimizer from any adapter group (prefer default, then first available) optimizer = None optimizer_adapter = None - + if _default_adapter_name in self.optimizer_group: - optimizer = self.optimizer_group[ - _default_adapter_name].optimizer + optimizer = self.optimizer_group[_default_adapter_name].optimizer optimizer_adapter = _default_adapter_name else: for name, group in self.optimizer_group.items(): @@ -350,17 +330,16 @@ def _lazy_wrap_model(self): optimizer = group.optimizer optimizer_adapter = name break - + if optimizer is not None: - self.model, optimizer = self.strategy.wrap_model( - self.model, optimizer) + self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) self.optimizer_group[optimizer_adapter].optimizer = optimizer self._model_wrapped = True - + @remote_function(dispatch='all') def prepare_training(self, **kwargs): """Prepare model for training. - + Note: In Ray-based Megatron training, we skip DDP wrapping to avoid deadlocks from collective operations. Each DP replica trains independently. This method still calls _lazy_wrap_model for any non-DDP setup needed. @@ -368,22 +347,20 @@ def prepare_training(self, **kwargs): self._lazy_wrap_model() @remote_function() - def forward(self, *, inputs: Union[InputFeature, List[InputFeature], - Trajectory, List[Trajectory]], - **kwargs): + def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Forward pass with Megatron model. - + Args: inputs: Model inputs. **kwargs: Additional arguments including adapter_name. - + Returns: Model outputs. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() - + # Encode inputs if needed if isinstance(inputs, dict) and 'input_ids' not in inputs: if optimizer_config.template is not None: @@ -391,33 +368,33 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], if isinstance(inputs, list) and 'input_ids' not in inputs[0]: if optimizer_config.template is not None: inputs = optimizer_config.template.batch_encode(inputs) - + # Process inputs processor: InputProcessor = optimizer_config.processor if processor is not None: inputs: Dict[str, Any] = processor(inputs) - + labels = inputs.get('labels', None) if 'labels' in inputs: try: del inputs['labels'] except (TypeError, KeyError): pass # Some dict-like types don't support deletion - + # Forward through model outputs = self._forward_step(inputs) - + inputs['labels'] = labels optimizer_config.inputs = inputs optimizer_config.outputs = outputs return outputs - + def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Execute forward step with pipeline parallelism support. - + Args: inputs: Processed inputs. - + Returns: Model outputs. 
""" @@ -426,16 +403,16 @@ def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return self._forward_step_pipeline(inputs) else: return self._forward_step_simple(inputs) - + def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Simple forward step without pipeline parallelism.""" model = self.strategy.unwrap_model(self.model) - + # Prepare inputs for Megatron input_ids = inputs.get('input_ids') attention_mask = inputs.get('attention_mask') position_ids = inputs.get('position_ids') - + # Create position_ids if not provided if position_ids is None and input_ids is not None: position_ids = torch.arange( @@ -443,47 +420,46 @@ def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: device=input_ids.device, dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - + # Forward pass outputs = model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, ) - + return {'logits': outputs} - + def _forward_step_pipeline(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Forward step with pipeline parallelism. - + Note: For PP > 1, the forward pass is handled by Megatron's pipeline scheduler in forward_backward(). This method is for simple forward-only inference. For training, use forward_backward() which uses get_forward_backward_func(). """ from twinkle.megatron.utils import forward_step_helper - + model = self.strategy.unwrap_model(self.model) - + # Use pipeline forward helper output = forward_step_helper( model, inputs, model.config, ) - + if output is not None: return {'logits': output} return {} @remote_function() - def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], - List[Trajectory]], **kwargs): + def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Forward pass without gradient computation. - + Args: inputs: Model inputs. **kwargs: Additional arguments. - + Returns: Model outputs. 
""" @@ -493,23 +469,23 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], @remote_function(collect='avg') def calculate_loss(self, **kwargs): """Calculate loss from forward outputs. - + Args: **kwargs: Additional arguments including adapter_name. - + Returns: Loss value as numpy array. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] loss_instance: Loss = optimizer_config.loss_instance - + inputs = optimizer_config.inputs outputs = optimizer_config.outputs - + assert inputs is not None and outputs is not None, \ 'Cannot calculate loss of empty inputs and outputs' - + loss_value = loss_instance(inputs, outputs, **kwargs) optimizer_config.loss_value = loss_value return loss_value.detach().cpu().float().numpy() @@ -517,191 +493,138 @@ def calculate_loss(self, **kwargs): @remote_function() def backward(self, **kwargs): """Backward pass. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] loss_value = optimizer_config.loss_value - + assert loss_value is not None, 'Do forwarding and calculating loss before backward' - + _gas = optimizer_config.gradient_accumulation_steps if 'gradient_accumulation_steps' in kwargs: _gas = kwargs['gradient_accumulation_steps'] - + loss_value = loss_value / _gas loss_value.backward() optimizer_config.cur_step += 1 @remote_function(dispatch='all', collect='avg', sync=True) - def forward_backward(self, - *, - inputs: Union[InputFeature, List[InputFeature], - Trajectory, List[Trajectory]], - num_microbatches: int = 1, - **kwargs): + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Combined forward and backward pass using Megatron's scheduler. 
- + Note: sync=True is required for Ray mode because Megatron's pipeline parallel uses NCCL P2P communication that requires all ranks to enter the function simultaneously. - + Always uses Megatron's get_forward_backward_func() which handles: - Pipeline scheduling (1F1B, interleaved, or no-pipeline) - Communication between stages (using proper process groups for multi-tenant isolation) - - Gradient accumulation across microbatches - + - Gradient accumulation + Args: - inputs: Model inputs. Can be: - - A single batch dict (num_microbatches=1) - - A list of batch dicts (num_microbatches=len(inputs)) - - An iterator yielding batch dicts - num_microbatches: Number of microbatches to process in one call. - If inputs is a list, this is inferred from len(inputs). - Using num_microbatches > 1 enables Megatron's native gradient - accumulation with better memory management and compute overlap. + inputs: Model inputs. **kwargs: Additional arguments. - + Returns: - Average loss value across all microbatches. + Loss value. """ from functools import partial from megatron.core.pipeline_parallel import get_forward_backward_func - + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() - - # Handle different input formats - # 1. Single batch dict -> wrap in list - # 2. List of batches -> use as-is - # 3. 
Iterator -> convert to list - if isinstance(inputs, dict): - microbatch_list = [inputs] - elif hasattr(inputs, - '__iter__') and not isinstance(inputs, (list, tuple)): - # Iterator - convert to list - microbatch_list = list(inputs) - else: - microbatch_list = list(inputs) - - # Infer num_microbatches from inputs if list is provided - if len(microbatch_list) > 1: - num_microbatches = len(microbatch_list) - - # Helper to convert list/numpy to tensor - def ensure_tensor(value): - if value is None: - return None - if isinstance(value, torch.Tensor): - return value - if isinstance(value, list): - return torch.tensor(value) - if hasattr(value, '__array__'): # numpy array - return torch.from_numpy(value) - return value - - # Process each microbatch - processed_batches = [] - for batch in microbatch_list: - # Encode inputs if needed - if isinstance(batch, dict) and 'input_ids' not in batch: - if optimizer_config.template is not None: - batch = optimizer_config.template.encode(batch) - - # Process inputs - processor = optimizer_config.processor - if processor is not None: - batch = processor(batch) - - # Ensure all tensor fields are proper tensors - if isinstance(batch, dict): - for key in ['input_ids', 'attention_mask', 'labels', 'position_ids']: - if key in batch: - batch[key] = ensure_tensor(batch[key]) - - processed_batches.append(batch) - - # Get first batch for shape info (all batches should have same shape) - first_batch = processed_batches[0] - + + # Encode inputs if needed + if isinstance(inputs, dict) and 'input_ids' not in inputs: + if optimizer_config.template is not None: + inputs = optimizer_config.template.encode(inputs) + if isinstance(inputs, list) and 'input_ids' not in inputs[0]: + if optimizer_config.template is not None: + inputs = optimizer_config.template.batch_encode(inputs) + + # Process inputs + processor = optimizer_config.processor + if processor is not None: + inputs = processor(inputs) + + # Store labels before removing from inputs + labels = 
inputs.get('labels', None) + if 'labels' in inputs: + try: + del inputs['labels'] + except (TypeError, KeyError): + pass # Some dict-like types don't support deletion + # Get CP size for sequence padding and splitting cp_size = self.strategy.cp_size cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0 - - # Get sequence length and batch size from first batch - input_ids = first_batch.get('input_ids') - if input_ids is not None and isinstance(input_ids, torch.Tensor): - original_seq_length = input_ids.shape[1] if input_ids.dim() > 1 else input_ids.shape[0] - micro_batch_size = input_ids.shape[0] if input_ids.dim() > 1 else 1 - else: - original_seq_length = 1 - micro_batch_size = 1 - + + # Get sequence length and batch size + # Note: Megatron's schedule internally divides seq_length by cp_size + # So we pass the padded full sequence length here + original_seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1 + micro_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 + # For CP > 1, pad seq_length to be divisible by 2*cp_size if cp_size > 1: divisor = 2 * cp_size if original_seq_length % divisor != 0: - seq_length = original_seq_length + ( - divisor - original_seq_length % divisor) + seq_length = original_seq_length + (divisor - original_seq_length % divisor) else: seq_length = original_seq_length else: seq_length = original_seq_length - + + # Move labels to GPU if needed + if labels is not None and not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, device=torch.cuda.current_device()) + elif labels is not None: + labels = labels.to(torch.cuda.current_device()) + def split_tensor_for_cp(tensor, dim=-1): """ Split tensor along sequence dimension for Context Parallel. - + With causal masking, split into 2*CP chunks and assign alternating chunks to balance workload across CP ranks. 
For CP rank i: chunks [i, 2*CP-1-i] """ if tensor is None or cp_size <= 1: return tensor - + if dim < 0: dim = (dim + tensor.ndim) % tensor.ndim - + seq_len = tensor.shape[dim] - + # Reshape to [batch, 2*cp_size, seq_per_chunk, ...] view_shape = list(tensor.shape) - view_shape[dim:dim + 1] = [2 * cp_size, seq_len // (2 * cp_size)] + view_shape[dim:dim+1] = [2 * cp_size, seq_len // (2 * cp_size)] reshaped = tensor.view(*view_shape) - + # Select chunks [cp_rank, 2*cp_size-1-cp_rank] - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], - device='cpu', - pin_memory=True).cuda(non_blocking=True) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], + device='cpu', pin_memory=True).cuda(non_blocking=True) selected = reshaped.index_select(dim, index) - + # Reshape back: [batch, 2*seq_per_chunk, ...] out_shape = list(tensor.shape) out_shape[dim] = seq_len // cp_size return selected.reshape(*out_shape) - + # Define forward step function for Megatron # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) def forward_step_func(data_iterator, model): batch = next(data_iterator) - - # Move tensors to CUDA with non_blocking=True for async transfer - def to_cuda_non_blocking(tensor): - if tensor is None: - return None - if isinstance(tensor, torch.Tensor) and not tensor.is_cuda: - return tensor.cuda(non_blocking=True) - return tensor - - input_ids = to_cuda_non_blocking(batch.get('input_ids')) - position_ids = to_cuda_non_blocking(batch.get('position_ids')) - attention_mask = to_cuda_non_blocking(batch.get('attention_mask')) - batch_labels = to_cuda_non_blocking( - batch.get('labels')) # Labels should be in each batch - + input_ids = batch.get('input_ids') + position_ids = batch.get('position_ids') + attention_mask = batch.get('attention_mask') + batch_labels = batch.get('labels', labels) # Use batch labels or passed labels + # Pad sequence for Context Parallel compatibility # Megatron's RoPE requires seq_len % (2 * cp_size) == 0 if 
cp_size > 1 and input_ids is not None: @@ -710,24 +633,17 @@ def to_cuda_non_blocking(tensor): if seq_len % divisor != 0: pad_len = divisor - (seq_len % divisor) # Pad input_ids - input_ids = torch.nn.functional.pad(input_ids, - (0, pad_len), - value=0) + input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=0) # Pad labels if present if batch_labels is not None: - batch_labels = torch.nn.functional.pad(batch_labels, - (0, pad_len), - value=-100) + batch_labels = torch.nn.functional.pad(batch_labels, (0, pad_len), value=-100) # Pad attention_mask if present if attention_mask is not None: - attention_mask = torch.nn.functional.pad( - attention_mask, (0, pad_len), value=0) + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_len), value=0) # Pad position_ids if present if position_ids is not None: - position_ids = torch.nn.functional.pad(position_ids, - (0, pad_len), - value=0) - + position_ids = torch.nn.functional.pad(position_ids, (0, pad_len), value=0) + # Create position_ids if not provided if position_ids is None and input_ids is not None: position_ids = torch.arange( @@ -735,7 +651,7 @@ def to_cuda_non_blocking(tensor): device=input_ids.device, dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - + # Split tensors for Context Parallel # Each CP rank processes a portion of the sequence if cp_size > 1: @@ -743,7 +659,7 @@ def to_cuda_non_blocking(tensor): position_ids = split_tensor_for_cp(position_ids, dim=-1) attention_mask = split_tensor_for_cp(attention_mask, dim=-1) batch_labels = split_tensor_for_cp(batch_labels, dim=-1) - + # Forward pass with labels - Megatron will compute loss internally # This uses Megatron's compute_language_model_loss which properly handles # vocab parallel cross entropy @@ -753,193 +669,170 @@ def to_cuda_non_blocking(tensor): attention_mask=attention_mask, labels=batch_labels, # Pass labels to let Megatron compute loss ) - + # Megatron's compute_language_model_loss returns per-token loss 
[batch, seq] - # We need to aggregate it with loss_mask and return 3 values for proper per-token normalization + # We need to aggregate it with loss_mask def megatron_loss_func(labels_for_mask, cp_size, output_tensor): # output_tensor is per-token loss [batch, seq] # Create loss mask from labels (ignore -100) - loss_mask = (labels_for_mask != -100) - - # Compute per-token losses - losses = output_tensor.float() - - loss_sum = torch.sum(losses * loss_mask.float()) - local_num_tokens = loss_mask.sum().to(torch.int) - - # For CP > 1, aggregate across CP ranks + loss_mask = (labels_for_mask != -100).float() + + # Flatten and compute mean + losses = output_tensor.float().view(-1) + loss_mask_flat = loss_mask.view(-1) + + # Compute local sum and count + local_loss_sum = torch.sum(losses * loss_mask_flat) + local_count = loss_mask_flat.sum() + + # For CP > 1, aggregate loss across CP ranks + # Note: Megatron's schedules.py will multiply loss by cp_group_size + # for legacy 2-output loss_func. This assumes loss_func returns SUM/cp_size (MEAN). + # So we should return local MEAN (not global MEAN) and let Megatron handle it. if cp_size > 1: - # All-reduce loss sum and token count across CP ranks - loss_tensor = torch.cat( - [loss_sum.view(1), - local_num_tokens.float().view(1)]) + # All-reduce the count across CP ranks to get total token count + # This is needed for correct averaging + total_count = local_count.clone() torch.distributed.all_reduce( - loss_tensor, + total_count, op=torch.distributed.ReduceOp.SUM, - group=mpu.get_context_parallel_group()) - loss_sum = loss_tensor[0] - local_num_tokens = loss_tensor[1].to(torch.int) - - # 1. loss (sum, will be divided by num_tokens by Megatron) - # 2. local_num_tokens (for proper averaging) - # 3. 
loss_dict for logging - reporting_loss = torch.cat([ - loss_sum.detach().view(1), - local_num_tokens.float().view(1) - ]) - - return (loss_sum, local_num_tokens, { - 'lm loss': reporting_loss - }) - - return output_tensor, partial(megatron_loss_func, batch_labels, - cp_size) - + group=mpu.get_context_parallel_group() + ) + + # Return local_loss_sum / total_count + # Megatron will multiply by cp_size, so the final result is: + # (local_loss_sum / total_count) * cp_size + # = (local_loss_sum * cp_size) / total_count + # But we want: SUM(local_loss_sum) / total_count + # So we need to do all_reduce on loss_sum too + total_loss_sum = local_loss_sum.clone() + torch.distributed.all_reduce( + total_loss_sum, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_context_parallel_group() + ) + + # Return global mean, but Megatron will multiply by cp_size + # So we divide by cp_size first to counteract that + loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size + else: + loss = local_loss_sum / local_count.clamp(min=1) + + return loss, {'loss': loss.detach()} + + return output_tensor, partial(megatron_loss_func, batch_labels, cp_size) + # Get Megatron's forward-backward function # This automatically selects the right scheduler based on PP config: # - PP > 1: forward_backward_pipelining_without_interleaving (or with interleaving if VPP) # - PP = 1: forward_backward_no_pipelining forward_backward_func = get_forward_backward_func() - - # Create iterator over all microbatches - # Megatron's scheduler will call next(data_iterator) num_microbatches times - data_iter = iter(processed_batches) - + + # Create single-item iterator + data_iter = iter([inputs]) + # Run forward-backward with Megatron's scheduler # Megatron handles all communication internally using proper process groups - # With num_microbatches > 1, gradients are accumulated across microbatches losses = forward_backward_func( forward_step_func=forward_step_func, data_iterator=data_iter, model=[self.model], - 
num_microbatches=num_microbatches, + num_microbatches=1, seq_length=seq_length, micro_batch_size=micro_batch_size, forward_only=False, ) - + # Extract loss from results (only last PP stage returns non-empty) - # With 3-value loss_func return, each loss_dict contains 'lm loss': [loss_sum, num_tokens] - # We aggregate across all microbatches using proper per-token averaging - total_loss_sum = 0.0 - total_num_tokens = 0 - + loss = 0.0 + if losses: for loss_dict in losses: - if isinstance(loss_dict, dict): - # New format: 'lm loss' contains [loss_sum, num_tokens] - if 'lm loss' in loss_dict: - reporting = loss_dict['lm loss'] - if isinstance(reporting, - torch.Tensor) and reporting.numel() == 2: - total_loss_sum += reporting[0].item() - total_num_tokens += int(reporting[1].item()) - elif isinstance(reporting, - (list, tuple)) and len(reporting) == 2: - total_loss_sum += float(reporting[0]) - total_num_tokens += int(reporting[1]) - # Legacy format: 'loss' contains average loss - elif 'loss' in loss_dict: - loss_val = loss_dict['loss'] - if isinstance(loss_val, torch.Tensor): - total_loss_sum += loss_val.item() - else: - total_loss_sum += float(loss_val) - total_num_tokens += 1 # Fallback: treat as 1 sample - - # Compute average loss (per-token average across all microbatches) - if total_num_tokens > 0: - loss = total_loss_sum / total_num_tokens - else: - loss = total_loss_sum / max(num_microbatches, 1) - + if isinstance(loss_dict, dict) and 'loss' in loss_dict: + loss = loss_dict['loss'] + break + elif isinstance(loss_dict, torch.Tensor): + loss = loss_dict + break + # For PP > 1, broadcast loss from last PP stage to all ranks # Note: mpu is imported at module level, no need to reimport if mpu.get_pipeline_model_parallel_world_size() > 1: if isinstance(loss, torch.Tensor): loss_tensor = loss.detach().clone() else: - loss_tensor = torch.tensor(loss, - dtype=torch.float32, - device=torch.cuda.current_device()) - + loss_tensor = torch.tensor(loss, dtype=torch.float32, 
device=torch.cuda.current_device()) + # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1) src_rank = mpu.get_pipeline_model_parallel_last_rank() pp_group = mpu.get_pipeline_model_parallel_group() - - torch.distributed.broadcast(loss_tensor, - src=src_rank, - group=pp_group) - + + torch.distributed.broadcast( + loss_tensor, + src=src_rank, + group=pp_group + ) + loss = loss_tensor.item() - + optimizer_config.cur_step += 1 - - # Note: finalize_model_grads is called inside forward_backward_func - # which already handles gradient synchronization across DP replicas. - # No additional barrier is needed here - adding one would hurt performance. - + + # Critical: Synchronize all DP replicas before returning + # This ensures all DP replicas complete the same training step before + # moving to the next batch, preventing P2P communication deadlocks + dp_world_size = mpu.get_data_parallel_world_size() + if dp_world_size > 1: + # Use barrier on DP+CP group to synchronize all replicas + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + dist.barrier(group=dp_cp_group) + if isinstance(loss, torch.Tensor): return loss.detach().cpu().float().numpy() return float(loss) @remote_function(dispatch='all') - def clip_grad_norm(self, - max_grad_norm: float = 1.0, - norm_type: int = 2, - **kwargs): + def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs): """Clip gradient norm. - + Args: max_grad_norm: Maximum gradient norm. norm_type: Type of norm to use. **kwargs: Additional arguments. - + Returns: Total norm of gradients. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) - optimizer_config = self.optimizer_group[adapter_name] - - # Check if using Megatron optimizer (handles clip_grad internally) - is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', - False) - if is_megatron_opt: - # Megatron optimizer handles gradient clipping in step() - # Return the grad_norm from last step if available - return getattr(optimizer_config, '_last_grad_norm', 0.0) - parameters = self._get_trainable_parameters(adapter_name).values() - + return torch.nn.utils.clip_grad_norm_( - parameters, max_grad_norm, - norm_type=norm_type).detach().cpu().numpy() + parameters, max_grad_norm, norm_type=norm_type + ).detach().cpu().numpy() @remote_function(dispatch='all') def step(self, **kwargs): """Optimizer step. - + For DDP-wrapped models: - Gradients are synchronized automatically during backward via DDP - + For non-DDP models (e.g., PEFT/LoRA): - Gradients are NOT synchronized across DP ranks - Each DP replica trains independently with different data - This is a common pattern for PEFT training where the overhead of gradient averaging is not worth the benefit - + Note: Uses dispatch='all' to ensure all workers execute this method. - + Args: **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync( - kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): return - + # For DDP-wrapped models, gradients are already synchronized during backward if self._is_model_ddp_wrapped(): # For Megatron DDP, ensure gradient buffers are finalized @@ -947,34 +840,24 @@ def step(self, **kwargs): self.model.finish_grad_sync() # For non-DDP models (e.g., PEFT), we skip gradient synchronization # Each DP replica trains independently, which is acceptable for PEFT - + optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer correctly before stepping' - - # Check if using Megatron optimizer (has different step() signature) - is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', - False) - if is_megatron_opt: - # Megatron optimizer step() returns (success, grad_norm, num_zeros) - success, grad_norm, num_zeros = optimizer.step() - # Store grad_norm for later retrieval - optimizer_config._last_grad_norm = grad_norm if grad_norm is not None else 0.0 - optimizer_config._last_step_success = success - else: - optimizer.step(**kwargs) - + + optimizer.step(**kwargs) + def _is_model_ddp_wrapped(self) -> bool: """Check if model is wrapped with DDP. - + Returns: - True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP). + True if model is wrapped with DDP (either Megatron DDP or PyTorch DDP). """ from torch.nn.parallel import DistributedDataParallel as TorchDDP return isinstance(self.model, (MegatronDDP, TorchDDP)) - + def _get_unwrapped_model(self) -> nn.Module: """Get the unwrapped model. - + Returns: The base model without DDP wrapper. """ @@ -983,49 +866,39 @@ def _get_unwrapped_model(self) -> nn.Module: @remote_function(dispatch='all') def zero_grad(self, **kwargs): """Zero gradients. 
- + For DDP-wrapped models, also zeros the DDP gradient buffers. - - Note: For DDP-wrapped models, zero_grad_buffer() is always called - because it's essential for the next training iteration. The - do_grad_sync check only affects the optimizer.zero_grad() call. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - # For DDP-wrapped models, ALWAYS zero the gradient buffer - # This is essential because Megatron's forward_backward_func uses - # the buffer's state to track gradient accumulation - if self._is_model_ddp_wrapped() and hasattr(self.model, - 'zero_grad_buffer'): - self.model.zero_grad_buffer() - - if not optimizer_config.do_grad_sync( - kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): return - + optimizer = optimizer_config.optimizer if optimizer is not None: - # Clear set_to_none for better compatibility - optimizer.zero_grad(set_to_none=True) + optimizer.zero_grad(**kwargs) + + # For Megatron DDP, zero the gradient buffer + if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'): + self.model.zero_grad_buffer() @remote_function() def lr_step(self, **kwargs): """Learning rate scheduler step. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync( - kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): return - + lr_scheduler = optimizer_config.lr_scheduler if lr_scheduler is not None: lr_scheduler.step(**kwargs) @@ -1033,22 +906,22 @@ def lr_step(self, **kwargs): @remote_function(dispatch='all') def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): """Set loss function. 
- + NOTE: For MegatronModel, the loss is computed internally by Megatron's GPTModel when labels are passed. This method is kept for API compatibility but the provided loss_cls is NOT used during forward_backward. - + Megatron internally uses vocab_parallel_cross_entropy which correctly handles tensor parallelism. This design ensures Loss classes don't need to be aware of the training backend (Megatron vs Transformers). - + Args: loss_cls: Loss class or string name (not used for Megatron). **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(loss_cls, str): if hasattr(twinkle.loss, loss_cls): loss_cls = getattr(twinkle.loss, loss_cls) @@ -1058,127 +931,38 @@ def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): optimizer_config.loss_instance = loss_cls() @remote_function(dispatch='all') - def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], - **kwargs): + def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs): """Set optimizer. - + Args: optimizer_cls: Optimizer class or string name. - - Standard PyTorch optimizers: 'AdamW', 'Adam', 'SGD', etc. - - 'MegatronDistributed': Use Megatron's distributed optimizer **kwargs: Additional arguments. - - For standard optimizers: lr, weight_decay, etc. - - For MegatronDistributed: use_distributed_optimizer, clip_grad, etc. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - # Check if requesting Megatron distributed optimizer - if optimizer_cls == 'MegatronDistributed' or kwargs.pop( - 'use_megatron_optimizer', False): - optimizer_config.optimizer = self._create_megatron_optimizer( - **kwargs) - optimizer_config.is_megatron_optimizer = True - return - + if isinstance(optimizer_cls, str): if hasattr(torch.optim, optimizer_cls): optimizer_cls = getattr(torch.optim, optimizer_cls) else: optimizer_cls = Plugin.load_plugin(optimizer_cls, Optimizer) - + optimizer_config.optimizer = optimizer_cls( - self._get_trainable_parameters(adapter_name).values(), **kwargs) - optimizer_config.is_megatron_optimizer = False - - def _create_megatron_optimizer(self, **kwargs): - """Create Megatron distributed optimizer. - - This provides significant memory savings for large models by sharding - optimizer states across DP replicas. - - Args: - **kwargs: Optimizer configuration options. - - lr: Learning rate (default: 1e-4) - - weight_decay: Weight decay (default: 0.0) - - use_distributed_optimizer: Shard optimizer states (default: True) - - clip_grad: Gradient clipping threshold (default: 1.0) - - bf16: Use bf16 training (default: True) - - adam_beta1, adam_beta2, adam_eps: Adam parameters - - Returns: - MegatronOptimizer instance. 
- """ - from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig - - # Build optimizer config - lr = kwargs.get('lr', 1e-4) - use_distributed_optimizer = kwargs.get('use_distributed_optimizer', - True) - - opt_config = OptimizerConfig( - optimizer='adam', - lr=lr, - min_lr=kwargs.get('min_lr', 0.0), - weight_decay=kwargs.get('weight_decay', 0.0), - adam_beta1=kwargs.get('adam_beta1', 0.9), - adam_beta2=kwargs.get('adam_beta2', 0.999), - adam_eps=kwargs.get('adam_eps', 1e-8), - clip_grad=kwargs.get('clip_grad', 1.0), - bf16=kwargs.get('bf16', True), - use_distributed_optimizer=use_distributed_optimizer, - overlap_param_gather=kwargs.get('overlap_param_gather', False), - log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), + self._get_trainable_parameters(adapter_name).values(), **kwargs ) - # For PEFT models, we need to handle the case where model is not DDP-wrapped - # We create a temporary wrapper to satisfy Megatron's optimizer requirements - model_chunks = [self.model] - - # Check if model has ddp_config (required for distributed optimizer) - if not hasattr(self.model, 'ddp_config') and use_distributed_optimizer: - # For PEFT models without DDP, fall back to non-distributed optimizer - # but still use Megatron's optimized implementation - opt_config.use_distributed_optimizer = False - if mpu.get_data_parallel_rank() == 0: - print( - 'Note: Falling back to non-distributed optimizer for PEFT model. 
' - 'For distributed optimizer, wrap model with MegatronDDP.') - - try: - optimizer = get_megatron_optimizer( - config=opt_config, - model_chunks=model_chunks, - ) - return optimizer - except Exception as e: - # Fallback to simple FP32 optimizer if Megatron optimizer fails - if mpu.get_data_parallel_rank() == 0: - print( - f'Warning: Failed to create Megatron optimizer ({e}), falling back to PyTorch AdamW' - ) - - params = [p for p in self.model.parameters() if p.requires_grad] - return torch.optim.AdamW(params, - lr=lr, - weight_decay=kwargs.get( - 'weight_decay', 0.0)) - - def _get_trainable_parameters( - self, - adapter_name: str = _default_adapter_name - ) -> Dict[str, nn.Parameter]: + def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) -> Dict[str, nn.Parameter]: """Get trainable parameters. - + Args: adapter_name: Name of adapter. - + Returns: Dict mapping parameter names to parameters. """ is_default = adapter_name == _default_adapter_name pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.') - + params = {} model = self.strategy.unwrap_model(self.model) for name, param in model.named_parameters(): @@ -1187,24 +971,22 @@ def _get_trainable_parameters( return params @remote_function(dispatch='all') - def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], - **kwargs): + def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwargs): """Set learning rate scheduler. - + Args: scheduler_cls: Scheduler class or string name. **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(scheduler_cls, str): if hasattr(torch.optim.lr_scheduler, scheduler_cls): - scheduler_cls = getattr(torch.optim.lr_scheduler, - scheduler_cls) + scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_cls) else: scheduler_cls = Plugin.load_plugin(scheduler_cls, LRScheduler) - + optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer before setting lr_scheduler' optimizer_config.lr_scheduler = scheduler_cls(optimizer, **kwargs) @@ -1212,58 +994,57 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], @remote_function(dispatch='all', sync=True) def save(self, output_dir: str, **kwargs): """Save model checkpoint. - + Args: output_dir: Output directory. **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron' - + if save_format == 'hf': self._save_hf_format(output_dir, adapter_name) else: self._save_megatron_format(output_dir, adapter_name) - + self._save_tokenizer(output_dir, adapter_name) - + def _save_hf_format(self, output_dir: str, adapter_name: str): """Save in HuggingFace format using bridge adapter. 
- + For distributed training: - All PP ranks participate in export (each has different layers) - Only DP rank 0 actually writes to disk - Uses barrier for synchronization - + For LoRA training: - Saves in PEFT format (adapter_model.safetensors + adapter_config.json) """ from twinkle.megatron.model.bridge import TwinkleBridgeAdapter import os - + # Check if this is LoRA training (has adapter_name other than default) is_lora = adapter_name and adapter_name != '' is_peft_format = is_lora - + # Create output directory on rank 0 only try: from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized( - ) else 0 + dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 except (ImportError, AssertionError): dp_rank = 0 - + if dp_rank == 0: os.makedirs(output_dir, exist_ok=True) - + # Synchronize before saving if dist.is_initialized(): dist.barrier() - + # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(self.hf_config.vocab_size) \ if hasattr(self, '_pad_vocab_size') else None - + # Use TwinkleBridgeAdapter for weight conversion # All ranks participate - bridge handles which ranks write adapter = TwinkleBridgeAdapter( @@ -1271,47 +1052,42 @@ def _save_hf_format(self, output_dir: str, adapter_name: str): tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, ep_size=self.strategy.ep_size, - model_path=self._model_path - if hasattr(self, '_model_path') else self.model_id, + model_path=self._model_path if hasattr(self, '_model_path') else self.model_id, padded_vocab_size=padded_vocab_size, ) - + # Get the model (unwrap if DDP wrapped) model = self.strategy.unwrap_model(self.model) - + # Use bridge to save weights - adapter.save_weights([model], - output_dir, - is_peft_format=is_peft_format) - + adapter.save_weights([model], output_dir, is_peft_format=is_peft_format) + # Save config on rank 0 only if dp_rank == 0: self.hf_config.save_pretrained(output_dir) - + def _pad_vocab_size(self, 
vocab_size: int) -> int: """Pad vocab size for tensor parallelism.""" divisor = self.strategy.tp_size * 128 return ((vocab_size + divisor - 1) // divisor) * divisor - + def _save_megatron_format(self, output_dir: str, adapter_name: str): """Save in Megatron checkpoint format.""" import os os.makedirs(output_dir, exist_ok=True) - + model = self.strategy.unwrap_model(self.model) state_dict = self._get_trainable_parameters(adapter_name) - + # Convert to CPU cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()} - + # Save with rank info for distributed checkpointing rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - - def _save_tokenizer(self, - output_dir: str, - adapter_name: str = _default_adapter_name): + + def _save_tokenizer(self, output_dir: str, adapter_name: str = _default_adapter_name): """Save tokenizer.""" optimizer_config = self.optimizer_group.get(adapter_name) if optimizer_config and optimizer_config.template: @@ -1320,10 +1096,10 @@ def _save_tokenizer(self, @remote_function(execute='first') def get_state_dict(self, **kwargs): """Get trainable state dict. - + Args: **kwargs: Additional arguments. - + Returns: State dict of trainable parameters. """ @@ -1331,24 +1107,24 @@ def get_state_dict(self, **kwargs): return self._get_trainable_parameters(adapter_name) _peft_patched = False - + @classmethod def _patch_peft_for_megatron(cls): """Patch PEFT's BaseTuner to handle Megatron's TransformerConfig. - + Megatron's TransformerConfig doesn't have a .get() method like HuggingFace configs. This patch handles the AttributeError that occurs when PEFT tries to check tie_word_embeddings. 
""" if cls._peft_patched: return - + from typing import List import torch.nn as nn from peft.tuners.tuners_utils import BaseTuner - + _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules - + def _get_tied_target_modules(self, model: nn.Module) -> List[str]: try: return _origin_get_tied_target_modules(self, model) @@ -1356,16 +1132,13 @@ def _get_tied_target_modules(self, model: nn.Module) -> List[str]: # Megatron's TransformerConfig doesn't have .get() method # Check share_embeddings_and_output_weights instead tied_target_modules = [] - if getattr(model, 'share_embeddings_and_output_weights', - False): + if getattr(model, 'share_embeddings_and_output_weights', False): for target_module in self.targeted_module_names: module_name = target_module.split('.')[-1] - if module_name in [ - 'output_layer', 'embedding', 'word_embeddings' - ]: + if module_name in ['output_layer', 'embedding', 'word_embeddings']: tied_target_modules.append(target_module) return tied_target_modules - + BaseTuner._get_tied_target_modules = _get_tied_target_modules cls._peft_patched = True @@ -1377,213 +1150,182 @@ def add_adapter_to_model( **kwargs, ): """Add LoRA adapter to model. - + Args: adapter_name: Name of the adapter. config_or_dir: LoRA config or path to saved adapter. **kwargs: Additional arguments. 
""" - from twinkle.megatron.utils import (prepare_lora_model, patch_deepcopy, - get_target_modules, - set_linear_is_expert) - + from twinkle.megatron.utils import ( + prepare_lora_model, patch_deepcopy, get_target_modules, set_linear_is_expert + ) + # Patch PEFT BaseTuner to handle Megatron's TransformerConfig # which doesn't have a .get() method like HuggingFace configs self._patch_peft_for_megatron() - + assert adapter_name, 'Use a non-empty adapter_name' - + model = self.strategy.unwrap_model(self.model) - + # Mark expert layers for MoE models set_linear_is_expert(model) - + if isinstance(config_or_dir, str): # Load from path config_or_dir = HubOperation.download_model(config_or_dir) from peft import PeftModel - model = PeftModel.from_pretrained(model, - config_or_dir, - adapter_name=adapter_name, - is_trainable=kwargs.get( - 'is_trainable', True)) + model = PeftModel.from_pretrained( + model, config_or_dir, adapter_name=adapter_name, + is_trainable=kwargs.get('is_trainable', True) + ) else: # Create from config from peft import LoraConfig, get_peft_model - + if not isinstance(config_or_dir, LoraConfig): # Convert dict to LoraConfig config_or_dir = LoraConfig(**config_or_dir) - + # Expand target_modules (e.g., 'all-linear' -> actual module names) if config_or_dir.target_modules: if isinstance(config_or_dir.target_modules, str): target_modules = [config_or_dir.target_modules] else: target_modules = list(config_or_dir.target_modules) - + expanded_modules = get_target_modules(model, target_modules) config_or_dir.target_modules = expanded_modules - + with patch_deepcopy(): - model = get_peft_model(model, - config_or_dir, - adapter_name=adapter_name) - + model = get_peft_model(model, config_or_dir, adapter_name=adapter_name) + # Update model reference if self._model_wrapped: if isinstance(self.model, MegatronDDP): self.model.module = model else: self.model = model - + # Add finish_grad_sync method for Megatron's finalize_model_grads compatibility # This is needed 
because Megatron's forward_backward_func calls finish_grad_sync # on model chunks, but PEFT models don't have this method by default if not hasattr(self.model, 'finish_grad_sync'): - def finish_grad_sync(): """Synchronize gradients across DP ranks for non-DDP models. - + This is a compatibility shim for Megatron's finalize_model_grads. - For PEFT/LoRA models, we manually all-reduce only trainable (LoRA) gradients. - - Optimizations: - 1. Only process gradients of trainable parameters (LoRA weights) - 2. Skip if DP size is 1 (no synchronization needed) - 3. Use coalesced all-reduce for efficiency + For PEFT/LoRA models, we manually all-reduce gradients. """ dp_world_size = mpu.get_data_parallel_world_size() - if dp_world_size <= 1: - return # No sync needed for DP=1 - - dp_cp_group = mpu.get_data_parallel_group( - with_context_parallel=True) - grads = [] - - # Only collect gradients from trainable parameters (LoRA weights) - # This is much faster than iterating all parameters - for param in self.model.parameters(): - if param.requires_grad and param.grad is not None: - grads.append(param.grad.data) - - if not grads: - return # No gradients to sync - - # Coalesced all-reduce for efficiency - from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, - op=dist.ReduceOp.AVG, - group=dp_cp_group) - - # Copy back synchronized gradients - for grad, synced in zip( - grads, _unflatten_dense_tensors(coalesced, grads)): - grad.copy_(synced) - + if dp_world_size > 1: + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + grads = [] + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + grads.append(param.grad.data) + + if grads: + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=dp_cp_group) + for grad, 
synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + grad.copy_(synced) + self.model.finish_grad_sync = finish_grad_sync - + # Create optimizer group for adapter self.optimizer_group[adapter_name] = MegatronOptimizerGroup() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config_or_dir - self.optimizer_group[ - adapter_name].gradient_accumulation_steps = kwargs.get( - 'gradient_accumulation_steps', 1) - + self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get( + 'gradient_accumulation_steps', 1 + ) + # Copy settings from default default_config = self.optimizer_group.get(_default_adapter_name) if default_config: if default_config.template: - self.optimizer_group[ - adapter_name].template = default_config.template + self.optimizer_group[adapter_name].template = default_config.template if default_config.processor: - self.optimizer_group[ - adapter_name].processor = default_config.processor + self.optimizer_group[adapter_name].processor = default_config.processor if default_config.loss_instance: - self.optimizer_group[ - adapter_name].loss_instance = default_config.loss_instance + self.optimizer_group[adapter_name].loss_instance = default_config.loss_instance @remote_function(dispatch='all') - def set_template(self, template_cls: Union[Type[template.Template], str], - **kwargs): + def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs): """Set template for input encoding. - + Args: template_cls: Template class or string name. **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(template_cls, str): if hasattr(template, template_cls): template_cls = getattr(template, template_cls) else: - template_cls = Plugin.load_plugin(template_cls, - template.Template) + template_cls = Plugin.load_plugin(template_cls, template.Template) optimizer_config.template = template_cls(self.model_id, **kwargs) @remote_function(dispatch='all') - def set_processor(self, processor_cls: Union[Type[InputProcessor], str], - **kwargs): + def set_processor(self, processor_cls: Union[Type[InputProcessor], str], **kwargs): """Set input processor. - + Args: processor_cls: Processor class or string name. **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(processor_cls, str): if hasattr(twinkle.processor, processor_cls): processor_cls = getattr(twinkle.processor, processor_cls) else: - processor_cls = Plugin.load_plugin(processor_cls, - InputProcessor) - optimizer_config.processor = processor_cls( - device_mesh=self.device_mesh, **kwargs) + processor_cls = Plugin.load_plugin(processor_cls, InputProcessor) + optimizer_config.processor = processor_cls(device_mesh=self.device_mesh, **kwargs) @remote_function(execute='first') def get_train_configs(self, **kwargs): """Get training configuration summary. - + Args: **kwargs: Additional arguments. - + Returns: Configuration summary string. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + expr = f'Backend: Megatron-Core\n' expr += f'TP size: {self.strategy.tp_size}\n' expr += f'PP size: {self.strategy.pp_size}\n' expr += f'CP size: {self.strategy.cp_size}\n' expr += f'EP size: {self.strategy.ep_size}\n' expr += f'Sequence Parallel: {self.strategy.sequence_parallel}\n' - + if optimizer_config.adapter_config is not None: config = optimizer_config.adapter_config.__dict__ - config = { - key: str(value) - for key, value in config.items() if value is not None - } + config = {key: str(value) for key, value in config.items() if value is not None} expr += f'Adapter config:\n{json.dumps(config, indent=2, ensure_ascii=False)}\n' - + if optimizer_config.optimizer: expr += f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n' expr += f'Learning rate: {optimizer_config.optimizer.defaults.get("lr", "N/A")}\n' if optimizer_config.lr_scheduler: expr += f'LR scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n' expr += f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n' - + return expr - + def __repr__(self): - return (f"MegatronModel(model_id='{self.model_id}', " - f'tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, ' - f'cp={self.strategy.cp_size}, ep={self.strategy.ep_size})') + return ( + f"MegatronModel(model_id='{self.model_id}', " + f"tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, " + f"cp={self.strategy.cp_size}, ep={self.strategy.ep_size})" + ) + From df32ac132f60d04ab6e17e5e3e6d635b49f26d2a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 16:26:11 +0800 Subject: [PATCH 13/22] Fix forward_backward for dense and MoE models Key fixes: 1. Simplified input processing to match working version (70ff0ba) - Process inputs once at the beginning, not per microbatch - Properly handle labels by storing separately before deletion 2. 
Fixed sequence_parallel padding for MoE models - Detect actual sequence_parallel setting from model.config - Bridge auto-enables sequence_parallel for MoE with TP > 1 - Pad sequence length to be divisible by TP size 3. Reverted loss_func to return 2 values (compatible with Megatron scheduler) - Old format: (loss, {'loss': loss}) - Was incorrectly returning 3 values causing compatibility issues Tested: - Dense model (Qwen2.5-7B) with TP=2, PP=2: Step 0 loss 1.168556 - MoE model (Qwen3-30B-A3B) with TP=2, EP=2: Step 0 loss 1.474237 --- src/twinkle/model/megatron.py | 917 +++++++++++++++++++++------------- 1 file changed, 566 insertions(+), 351 deletions(-) diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 281256b4..ed6062d4 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -7,19 +7,20 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler import twinkle -from twinkle import remote_class, remote_function, template, DeviceMesh +from twinkle import DeviceMesh, remote_class, remote_function, template from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.loss import Loss, MegatronCrossEntropyLoss from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils.plugin import Plugin + from .base import TwinkleModel from .strategy import MegatronStrategy @@ -29,7 +30,8 @@ from megatron.core.distributed import DistributedDataParallel as MegatronDDP from packaging import version MEGATRON_AVAILABLE = True - mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') except ImportError: MEGATRON_AVAILABLE = False mcore_013 = False @@ 
-38,7 +40,7 @@ @dataclass class MegatronOptimizerGroup: """Optimizer group for Megatron training. - + Similar to OptimizerGroup but adapted for Megatron's distributed training. """ adapter_name: str = None @@ -54,8 +56,14 @@ class MegatronOptimizerGroup: gradient_accumulation_steps: int = 1 cur_step: int = 0 dp_group = None + # Megatron optimizer specific fields + is_megatron_optimizer: bool = False + _last_grad_norm: float = 0.0 + _last_step_success: bool = True - def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: + def do_grad_sync(self, + gradient_accumulation_steps: Optional[int] = None + ) -> bool: """Check if gradient synchronization should happen.""" if gradient_accumulation_steps is None: gradient_accumulation_steps = self.gradient_accumulation_steps @@ -69,22 +77,21 @@ def check_megatron_available(): """Check if Megatron-Core is available.""" if not MEGATRON_AVAILABLE: raise ImportError( - "Megatron-Core is not installed. Please install it with: " - "pip install megatron-core" - ) + 'Megatron-Core is not installed. Please install it with: ' + 'pip install megatron-core') @remote_class(execute='all') class MegatronModel(TwinkleModel, nn.Module): """Megatron-Core model wrapper for twinkle training framework. - + Note: Uses execute='all' to create workers on all ranks, which is required for Megatron's TP/DP parallelism where all ranks must participate in collective operations like gradient all-reduce. - + This class provides a similar API to TransformersModel but uses Megatron-Core as the training backend, supporting TP/PP/CP/EP parallelism. - + Args: pretrained_model_name_or_path: HuggingFace model path or ID. device_mesh: Twinkle DeviceMesh for distributed training. @@ -97,7 +104,6 @@ class MegatronModel(TwinkleModel, nn.Module): use_distributed_optimizer: Use Megatron's distributed optimizer. **kwargs: Additional arguments passed to model initialization. 
""" - def __init__( self, pretrained_model_name_or_path: str, @@ -110,28 +116,30 @@ def __init__( mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', use_distributed_optimizer: bool = True, load_weights: bool = True, - use_megatron_bridge: bool = True, # Use bridge-based initialization (recommended) - recompute_granularity: Optional[str] = 'selective', # Activation checkpointing + use_megatron_bridge: + bool = True, # Use bridge-based initialization (recommended) + recompute_granularity: Optional[ + str] = 'selective', # Activation checkpointing recompute_modules: Optional[list] = None, # Modules to recompute **kwargs, ): check_megatron_available() nn.Module.__init__(self) - + self.model_id = pretrained_model_name_or_path self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.use_megatron_bridge = use_megatron_bridge self.recompute_granularity = recompute_granularity self.recompute_modules = recompute_modules - + # Load HuggingFace config first model_path = HubOperation.download_model(pretrained_model_name_or_path) self._load_hf_config(model_path) - + # Store model_path for later use self._model_path = model_path - + # Create Megatron strategy self.strategy = MegatronStrategy( tensor_model_parallel_size=tensor_model_parallel_size, @@ -142,25 +150,27 @@ def __init__( use_distributed_optimizer=use_distributed_optimizer, mixed_precision=mixed_precision, ) - + # Initialize parallel state (skip if using bridge init, as it handles this) if not use_megatron_bridge: self.strategy.initialize() - + # Create Megatron model - self.model = self._create_megatron_model(model_path, load_weights, **kwargs) - + self.model = self._create_megatron_model(model_path, load_weights, + **kwargs) + self._model_wrapped = False # This correctly handles vocab sharding in Tensor Parallelism self.optimizer_group: Dict[str, MegatronOptimizerGroup] = { - _default_adapter_name: MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) + _default_adapter_name: + 
MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) } - + def _load_hf_config(self, model_path: str): """Load HuggingFace model config.""" from transformers import AutoConfig self.hf_config = AutoConfig.from_pretrained(model_path) - + def _create_megatron_model( self, model_path: str, @@ -168,12 +178,12 @@ def _create_megatron_model( **kwargs, ) -> nn.Module: """Create Megatron model from HuggingFace checkpoint. - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. **kwargs: Additional arguments. - + Returns: Megatron model on GPU. """ @@ -182,15 +192,17 @@ def _create_megatron_model( params_dtype = torch.float16 elif self.mixed_precision == 'no': params_dtype = torch.float32 - + if self.use_megatron_bridge: # Use bridge-based initialization (recommended) # This ensures all patches are applied and config is correctly generated - return self._create_megatron_model_with_bridge(model_path, load_weights, params_dtype, **kwargs) + return self._create_megatron_model_with_bridge( + model_path, load_weights, params_dtype, **kwargs) else: # Use twinkle's native initialization - return self._create_megatron_model_native(model_path, load_weights, params_dtype, **kwargs) - + return self._create_megatron_model_native(model_path, load_weights, + params_dtype, **kwargs) + def _create_megatron_model_with_bridge( self, model_path: str, @@ -199,25 +211,25 @@ def _create_megatron_model_with_bridge( **kwargs, ) -> nn.Module: """Create Megatron model using bridge-based initialization flow. - + This approach uses TwinkleBridgeInitializer for independent initialization It includes: - Proper config conversion from HuggingFace to Megatron format - Correct Megatron initialization (initialize_megatron) - Correct model creation - Weight loading with TwinkleGPTBridge - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. params_dtype: Parameter dtype. **kwargs: Additional arguments. 
- + Returns: Megatron model on GPU. """ from twinkle.megatron.model.bridge import TwinkleBridgeInitializer - + # Create bridge-based initializer self._bridge_initializer = TwinkleBridgeInitializer( tp_size=self.strategy.tp_size, @@ -227,24 +239,30 @@ def _create_megatron_model_with_bridge( params_dtype=params_dtype, use_cpu_initialization=False, attention_backend='flash', # Use flash for training performance + sequence_parallel=self.strategy.sequence_parallel, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules, recompute_method=getattr(self, 'recompute_method', None), recompute_num_layers=getattr(self, 'recompute_num_layers', None), ) - + # Create model (this calls initialize_megatron internally) - model = self._bridge_initializer.create_model(model_path, load_weights=load_weights) - + model = self._bridge_initializer.create_model( + model_path, load_weights=load_weights) + # Update strategy state since bridge has initialized Megatron self.strategy._initialized = True self.strategy._parallel_state = mpu - + + # Save transformer config for DDP wrapping + self._transformer_config = getattr(self._bridge_initializer, + '_transformer_config', None) + # Move to GPU model = self._move_model_to_gpu(model) - + return model - + def _create_megatron_model_native( self, model_path: str, @@ -253,20 +271,20 @@ def _create_megatron_model_native( **kwargs, ) -> nn.Module: """Create Megatron model using twinkle's native initialization. - + This is the fallback method when bridge is not available. - + Args: model_path: Path to HuggingFace model. load_weights: Whether to load weights. params_dtype: Parameter dtype. **kwargs: Additional arguments. - + Returns: Megatron model on GPU. 
""" from twinkle.megatron.model.initializer import MegatronModelInitializer - + initializer = MegatronModelInitializer( tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, @@ -275,43 +293,44 @@ def _create_megatron_model_native( sequence_parallel=self.strategy.sequence_parallel, params_dtype=params_dtype, ) - + # Create model model = initializer.create_gpt_model(self.hf_config, **kwargs) - + # Load weights if load_weights: initializer.load_from_hf(model, model_path, self.hf_config) - + model = self._move_model_to_gpu(model) - + return model - + def _move_model_to_gpu(self, model: nn.Module) -> nn.Module: """Move model to correct GPU device. - + This method handles moving parameters, buffers, and any cached tensors (like RoPE embeddings) to the correct device for distributed training. """ # Determine the target device based on local rank - local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0 + local_rank = dist.get_rank() % torch.cuda.device_count( + ) if dist.is_initialized() else 0 device = torch.device(f'cuda:{local_rank}') - + # Set CUDA device explicitly torch.cuda.set_device(local_rank) - + # Move all parameters and buffers to GPU model = model.to(device) - + # Force synchronize to ensure all transfers complete if torch.cuda.is_available(): torch.cuda.synchronize(device) - + return model - + def _lazy_wrap_model(self): """Lazily wrap model with distributed wrapper. - + Note: This should only be called after prepare_training() has been executed on all workers. Direct calls from forward() may cause deadlocks if not all DP ranks are participating. 
def _lazy_wrap_model(self):
    """Wrap the model with the distributed wrapper on first use.

    Must only run after ``prepare_training()`` has executed on every worker;
    calling it from ``forward()`` on a subset of DP ranks can deadlock.

    NOTE(review): reconstructed from a partial hunk — the early-return guard
    on ``self._model_wrapped`` is assumed; confirm against the original file.
    """
    if getattr(self, '_model_wrapped', False):
        return

    # Find an optimizer from any adapter group (prefer default, then first
    # available).
    optimizer = None
    optimizer_adapter = None
    if _default_adapter_name in self.optimizer_group:
        optimizer = self.optimizer_group[_default_adapter_name].optimizer
        optimizer_adapter = _default_adapter_name
    else:
        for name, group in self.optimizer_group.items():
            if group.optimizer is not None:
                optimizer = group.optimizer
                optimizer_adapter = name
                break

    if optimizer is not None:
        self.model, optimizer = self.strategy.wrap_model(self.model, optimizer)
        self.optimizer_group[optimizer_adapter].optimizer = optimizer
        self._model_wrapped = True


@remote_function(dispatch='all')
def prepare_training(self, **kwargs):
    """Prepare model for training.

    Note: In Ray-based Megatron training, we skip DDP wrapping to avoid
    deadlocks from collective operations. Each DP replica trains
    independently. This method still calls _lazy_wrap_model for any non-DDP
    setup needed.
    """
    self._lazy_wrap_model()


@remote_function()
def forward(self, *, inputs: Union[InputFeature, List[InputFeature],
                                   Trajectory, List[Trajectory]],
            **kwargs):
    """Forward pass with Megatron model.

    Args:
        inputs: Tokenized batch dict / list, or raw samples that the
            adapter's template can encode.
        **kwargs: Additional arguments including ``adapter_name``.

    Returns:
        Model outputs dict (``logits``).
    """
    adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
    optimizer_config = self.optimizer_group[adapter_name]
    self._lazy_wrap_model()

    # Encode raw samples if they are not tokenized yet.
    if isinstance(inputs, dict) and 'input_ids' not in inputs:
        if optimizer_config.template is not None:
            # NOTE(review): this encode call was cut from the patch hunk;
            # template.encode is assumed — confirm against the original file.
            inputs = optimizer_config.template.encode(inputs)
    # Fix: guard against an empty list before peeking at inputs[0]
    # (the original raised IndexError on []).
    if isinstance(inputs, list) and inputs and 'input_ids' not in inputs[0]:
        if optimizer_config.template is not None:
            inputs = optimizer_config.template.batch_encode(inputs)

    # Optional per-adapter input processing (padding, device placement, ...).
    processor: InputProcessor = optimizer_config.processor
    if processor is not None:
        inputs: Dict[str, Any] = processor(inputs)

    # Remove labels before the model call (Megatron would otherwise compute
    # loss internally); they are restored below for calculate_loss().
    labels = inputs.get('labels', None)
    if 'labels' in inputs:
        try:
            del inputs['labels']
        except (TypeError, KeyError):
            pass  # Some dict-like types don't support deletion

    outputs = self._forward_step(inputs)

    inputs['labels'] = labels
    optimizer_config.inputs = inputs
    optimizer_config.outputs = outputs
    return outputs
""" @@ -403,16 +426,16 @@ def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return self._forward_step_pipeline(inputs) else: return self._forward_step_simple(inputs) - + def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Simple forward step without pipeline parallelism.""" model = self.strategy.unwrap_model(self.model) - + # Prepare inputs for Megatron input_ids = inputs.get('input_ids') attention_mask = inputs.get('attention_mask') position_ids = inputs.get('position_ids') - + # Create position_ids if not provided if position_ids is None and input_ids is not None: position_ids = torch.arange( @@ -420,46 +443,47 @@ def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: device=input_ids.device, dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - + # Forward pass outputs = model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, ) - + return {'logits': outputs} - + def _forward_step_pipeline(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Forward step with pipeline parallelism. - + Note: For PP > 1, the forward pass is handled by Megatron's pipeline scheduler in forward_backward(). This method is for simple forward-only inference. For training, use forward_backward() which uses get_forward_backward_func(). """ from twinkle.megatron.utils import forward_step_helper - + model = self.strategy.unwrap_model(self.model) - + # Use pipeline forward helper output = forward_step_helper( model, inputs, model.config, ) - + if output is not None: return {'logits': output} return {} @remote_function() - def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): + def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], + List[Trajectory]], **kwargs): """Forward pass without gradient computation. - + Args: inputs: Model inputs. **kwargs: Additional arguments. - + Returns: Model outputs. 
""" @@ -469,23 +493,23 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T @remote_function(collect='avg') def calculate_loss(self, **kwargs): """Calculate loss from forward outputs. - + Args: **kwargs: Additional arguments including adapter_name. - + Returns: Loss value as numpy array. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] loss_instance: Loss = optimizer_config.loss_instance - + inputs = optimizer_config.inputs outputs = optimizer_config.outputs - + assert inputs is not None and outputs is not None, \ 'Cannot calculate loss of empty inputs and outputs' - + loss_value = loss_instance(inputs, outputs, **kwargs) optimizer_config.loss_value = loss_value return loss_value.detach().cpu().float().numpy() @@ -493,51 +517,63 @@ def calculate_loss(self, **kwargs): @remote_function() def backward(self, **kwargs): """Backward pass. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] loss_value = optimizer_config.loss_value - + assert loss_value is not None, 'Do forwarding and calculating loss before backward' - + _gas = optimizer_config.gradient_accumulation_steps if 'gradient_accumulation_steps' in kwargs: _gas = kwargs['gradient_accumulation_steps'] - + loss_value = loss_value / _gas loss_value.backward() optimizer_config.cur_step += 1 @remote_function(dispatch='all', collect='avg', sync=True) - def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): + def forward_backward(self, + *, + inputs: Union[InputFeature, List[InputFeature], + Trajectory, List[Trajectory]], + num_microbatches: int = 1, + **kwargs): """Combined forward and backward pass using Megatron's scheduler. 
- + Note: sync=True is required for Ray mode because Megatron's pipeline parallel uses NCCL P2P communication that requires all ranks to enter the function simultaneously. - + Always uses Megatron's get_forward_backward_func() which handles: - Pipeline scheduling (1F1B, interleaved, or no-pipeline) - Communication between stages (using proper process groups for multi-tenant isolation) - - Gradient accumulation - + - Gradient accumulation across microbatches + Args: - inputs: Model inputs. + inputs: Model inputs. Can be: + - A single batch dict (num_microbatches=1) + - A list of batch dicts (num_microbatches=len(inputs)) + - An iterator yielding batch dicts + num_microbatches: Number of microbatches to process in one call. + If inputs is a list, this is inferred from len(inputs). + Using num_microbatches > 1 enables Megatron's native gradient + accumulation with better memory management and compute overlap. **kwargs: Additional arguments. - + Returns: - Loss value. + Average loss value across all microbatches. 
""" from functools import partial from megatron.core.pipeline_parallel import get_forward_backward_func - + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() - + # Encode inputs if needed if isinstance(inputs, dict) and 'input_ids' not in inputs: if optimizer_config.template is not None: @@ -545,12 +581,12 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr if isinstance(inputs, list) and 'input_ids' not in inputs[0]: if optimizer_config.template is not None: inputs = optimizer_config.template.batch_encode(inputs) - + # Process inputs processor = optimizer_config.processor if processor is not None: inputs = processor(inputs) - + # Store labels before removing from inputs labels = inputs.get('labels', None) if 'labels' in inputs: @@ -558,64 +594,76 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr del inputs['labels'] except (TypeError, KeyError): pass # Some dict-like types don't support deletion - - # Get CP size for sequence padding and splitting + + # Move labels to GPU if needed + if labels is not None and not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, device=torch.cuda.current_device()) + elif labels is not None: + labels = labels.to(torch.cuda.current_device()) + + # Get parallelism settings for sequence padding and splitting cp_size = self.strategy.cp_size + tp_size = self.strategy.tp_size + # Check actual sequence_parallel setting from model config + # Bridge may auto-enable sequence_parallel for MoE models + model = self.strategy.unwrap_model(self.model) + if hasattr(model, 'config') and hasattr(model.config, 'sequence_parallel'): + sequence_parallel = model.config.sequence_parallel + else: + sequence_parallel = self.strategy.sequence_parallel cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0 - + # Get sequence length and batch size - # Note: Megatron's schedule 
internally divides seq_length by cp_size - # So we pass the padded full sequence length here original_seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1 micro_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 - - # For CP > 1, pad seq_length to be divisible by 2*cp_size + + # Calculate padded seq_length based on parallelism requirements + # 1. For CP > 1: seq_len must be divisible by 2 * cp_size + # 2. For sequence_parallel with TP > 1: seq_len must be divisible by tp_size if cp_size > 1: divisor = 2 * cp_size - if original_seq_length % divisor != 0: - seq_length = original_seq_length + (divisor - original_seq_length % divisor) - else: - seq_length = original_seq_length + elif sequence_parallel and tp_size > 1: + divisor = tp_size else: - seq_length = original_seq_length - - # Move labels to GPU if needed - if labels is not None and not isinstance(labels, torch.Tensor): - labels = torch.tensor(labels, device=torch.cuda.current_device()) - elif labels is not None: - labels = labels.to(torch.cuda.current_device()) + divisor = 1 + if divisor > 1 and original_seq_length % divisor != 0: + seq_length = original_seq_length + (divisor - original_seq_length % divisor) + else: + seq_length = original_seq_length + def split_tensor_for_cp(tensor, dim=-1): """ Split tensor along sequence dimension for Context Parallel. - + With causal masking, split into 2*CP chunks and assign alternating chunks to balance workload across CP ranks. For CP rank i: chunks [i, 2*CP-1-i] """ if tensor is None or cp_size <= 1: return tensor - + if dim < 0: dim = (dim + tensor.ndim) % tensor.ndim - + seq_len = tensor.shape[dim] - + # Reshape to [batch, 2*cp_size, seq_per_chunk, ...] 
view_shape = list(tensor.shape) - view_shape[dim:dim+1] = [2 * cp_size, seq_len // (2 * cp_size)] + view_shape[dim:dim + 1] = [2 * cp_size, seq_len // (2 * cp_size)] reshaped = tensor.view(*view_shape) - + # Select chunks [cp_rank, 2*cp_size-1-cp_rank] - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], - device='cpu', pin_memory=True).cuda(non_blocking=True) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], + device='cpu', + pin_memory=True).cuda(non_blocking=True) selected = reshaped.index_select(dim, index) - + # Reshape back: [batch, 2*seq_per_chunk, ...] out_shape = list(tensor.shape) out_shape[dim] = seq_len // cp_size return selected.reshape(*out_shape) - + # Define forward step function for Megatron # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) def forward_step_func(data_iterator, model): @@ -624,26 +672,42 @@ def forward_step_func(data_iterator, model): position_ids = batch.get('position_ids') attention_mask = batch.get('attention_mask') batch_labels = batch.get('labels', labels) # Use batch labels or passed labels - - # Pad sequence for Context Parallel compatibility - # Megatron's RoPE requires seq_len % (2 * cp_size) == 0 - if cp_size > 1 and input_ids is not None: + + # Pad sequence for parallel compatibility + # 1. For CP > 1: Megatron's RoPE requires seq_len % (2 * cp_size) == 0 + # 2. 
For sequence_parallel: seq_len must be divisible by TP size + if input_ids is not None: seq_len = input_ids.shape[1] - divisor = 2 * cp_size - if seq_len % divisor != 0: + + # Calculate required divisor based on parallelism settings + if cp_size > 1: + divisor = 2 * cp_size + elif sequence_parallel and tp_size > 1: + divisor = tp_size + else: + divisor = 1 + + if divisor > 1 and seq_len % divisor != 0: pad_len = divisor - (seq_len % divisor) # Pad input_ids - input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=0) + input_ids = torch.nn.functional.pad(input_ids, + (0, pad_len), + value=0) # Pad labels if present if batch_labels is not None: - batch_labels = torch.nn.functional.pad(batch_labels, (0, pad_len), value=-100) + batch_labels = torch.nn.functional.pad(batch_labels, + (0, pad_len), + value=-100) # Pad attention_mask if present if attention_mask is not None: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_len), value=0) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad_len), value=0) # Pad position_ids if present if position_ids is not None: - position_ids = torch.nn.functional.pad(position_ids, (0, pad_len), value=0) - + position_ids = torch.nn.functional.pad(position_ids, + (0, pad_len), + value=0) + # Create position_ids if not provided if position_ids is None and input_ids is not None: position_ids = torch.arange( @@ -651,7 +715,7 @@ def forward_step_func(data_iterator, model): device=input_ids.device, dtype=torch.long, ).unsqueeze(0).expand(input_ids.shape[0], -1) - + # Split tensors for Context Parallel # Each CP rank processes a portion of the sequence if cp_size > 1: @@ -659,7 +723,7 @@ def forward_step_func(data_iterator, model): position_ids = split_tensor_for_cp(position_ids, dim=-1) attention_mask = split_tensor_for_cp(attention_mask, dim=-1) batch_labels = split_tensor_for_cp(batch_labels, dim=-1) - + # Forward pass with labels - Megatron will compute loss internally # This uses Megatron's 
compute_language_model_loss which properly handles # vocab parallel cross entropy @@ -669,29 +733,25 @@ def forward_step_func(data_iterator, model): attention_mask=attention_mask, labels=batch_labels, # Pass labels to let Megatron compute loss ) - + # Megatron's compute_language_model_loss returns per-token loss [batch, seq] # We need to aggregate it with loss_mask def megatron_loss_func(labels_for_mask, cp_size, output_tensor): # output_tensor is per-token loss [batch, seq] # Create loss mask from labels (ignore -100) loss_mask = (labels_for_mask != -100).float() - + # Flatten and compute mean losses = output_tensor.float().view(-1) loss_mask_flat = loss_mask.view(-1) - + # Compute local sum and count local_loss_sum = torch.sum(losses * loss_mask_flat) local_count = loss_mask_flat.sum() - + # For CP > 1, aggregate loss across CP ranks - # Note: Megatron's schedules.py will multiply loss by cp_group_size - # for legacy 2-output loss_func. This assumes loss_func returns SUM/cp_size (MEAN). - # So we should return local MEAN (not global MEAN) and let Megatron handle it. 
if cp_size > 1: - # All-reduce the count across CP ranks to get total token count - # This is needed for correct averaging + # All-reduce the count across CP ranks total_count = local_count.clone() torch.distributed.all_reduce( total_count, @@ -699,12 +759,7 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): group=mpu.get_context_parallel_group() ) - # Return local_loss_sum / total_count - # Megatron will multiply by cp_size, so the final result is: - # (local_loss_sum / total_count) * cp_size - # = (local_loss_sum * cp_size) / total_count - # But we want: SUM(local_loss_sum) / total_count - # So we need to do all_reduce on loss_sum too + # All-reduce the loss sum total_loss_sum = local_loss_sum.clone() torch.distributed.all_reduce( total_loss_sum, @@ -712,25 +767,25 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): group=mpu.get_context_parallel_group() ) - # Return global mean, but Megatron will multiply by cp_size - # So we divide by cp_size first to counteract that + # Return global mean, divided by cp_size to counteract Megatron's multiplication loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size else: loss = local_loss_sum / local_count.clamp(min=1) - + return loss, {'loss': loss.detach()} - - return output_tensor, partial(megatron_loss_func, batch_labels, cp_size) - + + return output_tensor, partial(megatron_loss_func, batch_labels, + cp_size) + # Get Megatron's forward-backward function # This automatically selects the right scheduler based on PP config: # - PP > 1: forward_backward_pipelining_without_interleaving (or with interleaving if VPP) # - PP = 1: forward_backward_no_pipelining forward_backward_func = get_forward_backward_func() - + # Create single-item iterator data_iter = iter([inputs]) - + # Run forward-backward with Megatron's scheduler # Megatron handles all communication internally using proper process groups losses = forward_backward_func( @@ -742,10 +797,10 @@ def 
megatron_loss_func(labels_for_mask, cp_size, output_tensor): micro_batch_size=micro_batch_size, forward_only=False, ) - + # Extract loss from results (only last PP stage returns non-empty) loss = 0.0 - + if losses: for loss_dict in losses: if isinstance(loss_dict, dict) and 'loss' in loss_dict: @@ -754,85 +809,94 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): elif isinstance(loss_dict, torch.Tensor): loss = loss_dict break - + # For PP > 1, broadcast loss from last PP stage to all ranks # Note: mpu is imported at module level, no need to reimport if mpu.get_pipeline_model_parallel_world_size() > 1: if isinstance(loss, torch.Tensor): loss_tensor = loss.detach().clone() else: - loss_tensor = torch.tensor(loss, dtype=torch.float32, device=torch.cuda.current_device()) - + loss_tensor = torch.tensor(loss, + dtype=torch.float32, + device=torch.cuda.current_device()) + # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1) src_rank = mpu.get_pipeline_model_parallel_last_rank() pp_group = mpu.get_pipeline_model_parallel_group() - - torch.distributed.broadcast( - loss_tensor, - src=src_rank, - group=pp_group - ) - + + torch.distributed.broadcast(loss_tensor, + src=src_rank, + group=pp_group) + loss = loss_tensor.item() - + optimizer_config.cur_step += 1 - - # Critical: Synchronize all DP replicas before returning - # This ensures all DP replicas complete the same training step before - # moving to the next batch, preventing P2P communication deadlocks - dp_world_size = mpu.get_data_parallel_world_size() - if dp_world_size > 1: - # Use barrier on DP+CP group to synchronize all replicas - dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) - dist.barrier(group=dp_cp_group) - + + # Note: finalize_model_grads is called inside forward_backward_func + # which already handles gradient synchronization across DP replicas. + # No additional barrier is needed here - adding one would hurt performance. 
+ if isinstance(loss, torch.Tensor): return loss.detach().cpu().float().numpy() return float(loss) @remote_function(dispatch='all') - def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs): + def clip_grad_norm(self, + max_grad_norm: float = 1.0, + norm_type: int = 2, + **kwargs): """Clip gradient norm. - + Args: max_grad_norm: Maximum gradient norm. norm_type: Type of norm to use. **kwargs: Additional arguments. - + Returns: Total norm of gradients. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + # Check if using Megatron optimizer (handles clip_grad internally) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', + False) + if is_megatron_opt: + # Megatron optimizer handles gradient clipping in step() + # Return the grad_norm from last step if available + return getattr(optimizer_config, '_last_grad_norm', 0.0) + parameters = self._get_trainable_parameters(adapter_name).values() - + return torch.nn.utils.clip_grad_norm_( - parameters, max_grad_norm, norm_type=norm_type - ).detach().cpu().numpy() + parameters, max_grad_norm, + norm_type=norm_type).detach().cpu().numpy() @remote_function(dispatch='all') def step(self, **kwargs): """Optimizer step. - + For DDP-wrapped models: - Gradients are synchronized automatically during backward via DDP - + For non-DDP models (e.g., PEFT/LoRA): - Gradients are NOT synchronized across DP ranks - Each DP replica trains independently with different data - This is a common pattern for PEFT training where the overhead of gradient averaging is not worth the benefit - + Note: Uses dispatch='all' to ensure all workers execute this method. - + Args: **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): return - + # For DDP-wrapped models, gradients are already synchronized during backward if self._is_model_ddp_wrapped(): # For Megatron DDP, ensure gradient buffers are finalized @@ -840,24 +904,34 @@ def step(self, **kwargs): self.model.finish_grad_sync() # For non-DDP models (e.g., PEFT), we skip gradient synchronization # Each DP replica trains independently, which is acceptable for PEFT - + optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer correctly before stepping' - - optimizer.step(**kwargs) - + + # Check if using Megatron optimizer (has different step() signature) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', + False) + if is_megatron_opt: + # Megatron optimizer step() returns (success, grad_norm, num_zeros) + success, grad_norm, num_zeros = optimizer.step() + # Store grad_norm for later retrieval + optimizer_config._last_grad_norm = grad_norm if grad_norm is not None else 0.0 + optimizer_config._last_step_success = success + else: + optimizer.step(**kwargs) + def _is_model_ddp_wrapped(self) -> bool: """Check if model is wrapped with DDP. - + Returns: - True if model is wrapped with DDP (either Megatron DDP or PyTorch DDP). + True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP). """ from torch.nn.parallel import DistributedDataParallel as TorchDDP return isinstance(self.model, (MegatronDDP, TorchDDP)) - + def _get_unwrapped_model(self) -> nn.Module: """Get the unwrapped model. - + Returns: The base model without DDP wrapper. """ @@ -866,39 +940,49 @@ def _get_unwrapped_model(self) -> nn.Module: @remote_function(dispatch='all') def zero_grad(self, **kwargs): """Zero gradients. 
- + For DDP-wrapped models, also zeros the DDP gradient buffers. - + + Note: For DDP-wrapped models, zero_grad_buffer() is always called + because it's essential for the next training iteration. The + do_grad_sync check only affects the optimizer.zero_grad() call. + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + + # For DDP-wrapped models, ALWAYS zero the gradient buffer + # This is essential because Megatron's forward_backward_func uses + # the buffer's state to track gradient accumulation + if self._is_model_ddp_wrapped() and hasattr(self.model, + 'zero_grad_buffer'): + self.model.zero_grad_buffer() + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): return - + optimizer = optimizer_config.optimizer if optimizer is not None: - optimizer.zero_grad(**kwargs) - - # For Megatron DDP, zero the gradient buffer - if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'): - self.model.zero_grad_buffer() + # Clear set_to_none for better compatibility + optimizer.zero_grad(set_to_none=True) @remote_function() def lr_step(self, **kwargs): """Learning rate scheduler step. - + Args: **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - - if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): return - + lr_scheduler = optimizer_config.lr_scheduler if lr_scheduler is not None: lr_scheduler.step(**kwargs) @@ -906,22 +990,22 @@ def lr_step(self, **kwargs): @remote_function(dispatch='all') def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): """Set loss function. 
@remote_function(dispatch='all')
def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs):
    """Set loss function.

    NOTE: For MegatronModel, the loss is computed internally by Megatron's
    GPTModel when labels are passed (vocab_parallel_cross_entropy, which
    correctly handles tensor parallelism). This method is kept for API
    compatibility so Loss classes stay backend-agnostic; the provided
    loss_cls is NOT used during forward_backward.

    Args:
        loss_cls: Loss class or string name (not used for Megatron).
        **kwargs: Additional arguments.
    """
    adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
    optimizer_config = self.optimizer_group[adapter_name]

    if isinstance(loss_cls, str):
        if hasattr(twinkle.loss, loss_cls):
            loss_cls = getattr(twinkle.loss, loss_cls)
        else:
            # NOTE(review): this fallback branch was cut from the patch
            # hunk; assumed to mirror set_optimizer's plugin lookup.
            loss_cls = Plugin.load_plugin(loss_cls, Loss)
    optimizer_config.loss_instance = loss_cls()


@remote_function(dispatch='all')
def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str],
                  **kwargs):
    """Set optimizer.

    Args:
        optimizer_cls: Optimizer class or string name.
            - Standard PyTorch optimizers: 'AdamW', 'Adam', 'SGD', etc.
            - 'MegatronDistributed': Use Megatron's distributed optimizer
        **kwargs: lr/weight_decay/... for standard optimizers, or
            use_distributed_optimizer/clip_grad/... for MegatronDistributed.
    """
    adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
    optimizer_config = self.optimizer_group[adapter_name]

    # Fix: pop the flag unconditionally. Previously the pop sat on the
    # right-hand side of a short-circuiting `or`, so when optimizer_cls was
    # 'MegatronDistributed' the flag was never removed and leaked into
    # **kwargs forwarded to _create_megatron_optimizer.
    use_megatron = kwargs.pop('use_megatron_optimizer', False)
    if optimizer_cls == 'MegatronDistributed' or use_megatron:
        optimizer_config.optimizer = self._create_megatron_optimizer(**kwargs)
        optimizer_config.is_megatron_optimizer = True
        return

    if isinstance(optimizer_cls, str):
        if hasattr(torch.optim, optimizer_cls):
            optimizer_cls = getattr(torch.optim, optimizer_cls)
        else:
            optimizer_cls = Plugin.load_plugin(optimizer_cls, Optimizer)

    optimizer_config.optimizer = optimizer_cls(
        self._get_trainable_parameters(adapter_name).values(), **kwargs)
    optimizer_config.is_megatron_optimizer = False
+ """ + from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig + + # Build optimizer config + lr = kwargs.get('lr', 1e-4) + use_distributed_optimizer = kwargs.get('use_distributed_optimizer', + True) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=lr, + min_lr=kwargs.get('min_lr', 0.0), + weight_decay=kwargs.get('weight_decay', 0.0), + adam_beta1=kwargs.get('adam_beta1', 0.9), + adam_beta2=kwargs.get('adam_beta2', 0.999), + adam_eps=kwargs.get('adam_eps', 1e-8), + clip_grad=kwargs.get('clip_grad', 1.0), + bf16=kwargs.get('bf16', True), + use_distributed_optimizer=use_distributed_optimizer, + overlap_param_gather=kwargs.get('overlap_param_gather', False), + log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), ) - def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) -> Dict[str, nn.Parameter]: + # For PEFT models, we need to handle the case where model is not DDP-wrapped + # We create a temporary wrapper to satisfy Megatron's optimizer requirements + model_chunks = [self.model] + + # Check if model has ddp_config (required for distributed optimizer) + if not hasattr(self.model, 'ddp_config') and use_distributed_optimizer: + # For PEFT models without DDP, fall back to non-distributed optimizer + # but still use Megatron's optimized implementation + opt_config.use_distributed_optimizer = False + if mpu.get_data_parallel_rank() == 0: + print( + 'Note: Falling back to non-distributed optimizer for PEFT model. 
' + 'For distributed optimizer, wrap model with MegatronDDP.') + + try: + optimizer = get_megatron_optimizer( + config=opt_config, + model_chunks=model_chunks, + ) + return optimizer + except Exception as e: + # Fallback to simple FP32 optimizer if Megatron optimizer fails + if mpu.get_data_parallel_rank() == 0: + print( + f'Warning: Failed to create Megatron optimizer ({e}), falling back to PyTorch AdamW' + ) + + params = [p for p in self.model.parameters() if p.requires_grad] + return torch.optim.AdamW(params, + lr=lr, + weight_decay=kwargs.get( + 'weight_decay', 0.0)) + + def _get_trainable_parameters( + self, + adapter_name: str = _default_adapter_name + ) -> Dict[str, nn.Parameter]: """Get trainable parameters. - + Args: adapter_name: Name of adapter. - + Returns: Dict mapping parameter names to parameters. """ is_default = adapter_name == _default_adapter_name pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.') - + params = {} model = self.strategy.unwrap_model(self.model) for name, param in model.named_parameters(): @@ -971,22 +1144,24 @@ def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) - return params @remote_function(dispatch='all') - def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwargs): + def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], + **kwargs): """Set learning rate scheduler. - + Args: scheduler_cls: Scheduler class or string name. **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(scheduler_cls, str): if hasattr(torch.optim.lr_scheduler, scheduler_cls): - scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_cls) + scheduler_cls = getattr(torch.optim.lr_scheduler, + scheduler_cls) else: scheduler_cls = Plugin.load_plugin(scheduler_cls, LRScheduler) - + optimizer = optimizer_config.optimizer assert optimizer is not None, 'Set optimizer before setting lr_scheduler' optimizer_config.lr_scheduler = scheduler_cls(optimizer, **kwargs) @@ -994,57 +1169,58 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwarg @remote_function(dispatch='all', sync=True) def save(self, output_dir: str, **kwargs): """Save model checkpoint. - + Args: output_dir: Output directory. **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron' - + if save_format == 'hf': self._save_hf_format(output_dir, adapter_name) else: self._save_megatron_format(output_dir, adapter_name) - + self._save_tokenizer(output_dir, adapter_name) - + def _save_hf_format(self, output_dir: str, adapter_name: str): """Save in HuggingFace format using bridge adapter. 
- + For distributed training: - All PP ranks participate in export (each has different layers) - Only DP rank 0 actually writes to disk - Uses barrier for synchronization - + For LoRA training: - Saves in PEFT format (adapter_model.safetensors + adapter_config.json) """ from twinkle.megatron.model.bridge import TwinkleBridgeAdapter import os - + # Check if this is LoRA training (has adapter_name other than default) is_lora = adapter_name and adapter_name != '' is_peft_format = is_lora - + # Create output directory on rank 0 only try: from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 + dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized( + ) else 0 except (ImportError, AssertionError): dp_rank = 0 - + if dp_rank == 0: os.makedirs(output_dir, exist_ok=True) - + # Synchronize before saving if dist.is_initialized(): dist.barrier() - + # Calculate padded vocab size padded_vocab_size = self._pad_vocab_size(self.hf_config.vocab_size) \ if hasattr(self, '_pad_vocab_size') else None - + # Use TwinkleBridgeAdapter for weight conversion # All ranks participate - bridge handles which ranks write adapter = TwinkleBridgeAdapter( @@ -1052,42 +1228,47 @@ def _save_hf_format(self, output_dir: str, adapter_name: str): tp_size=self.strategy.tp_size, pp_size=self.strategy.pp_size, ep_size=self.strategy.ep_size, - model_path=self._model_path if hasattr(self, '_model_path') else self.model_id, + model_path=self._model_path + if hasattr(self, '_model_path') else self.model_id, padded_vocab_size=padded_vocab_size, ) - + # Get the model (unwrap if DDP wrapped) model = self.strategy.unwrap_model(self.model) - + # Use bridge to save weights - adapter.save_weights([model], output_dir, is_peft_format=is_peft_format) - + adapter.save_weights([model], + output_dir, + is_peft_format=is_peft_format) + # Save config on rank 0 only if dp_rank == 0: self.hf_config.save_pretrained(output_dir) - + def _pad_vocab_size(self, 
vocab_size: int) -> int: """Pad vocab size for tensor parallelism.""" divisor = self.strategy.tp_size * 128 return ((vocab_size + divisor - 1) // divisor) * divisor - + def _save_megatron_format(self, output_dir: str, adapter_name: str): """Save in Megatron checkpoint format.""" import os os.makedirs(output_dir, exist_ok=True) - + model = self.strategy.unwrap_model(self.model) state_dict = self._get_trainable_parameters(adapter_name) - + # Convert to CPU cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()} - + # Save with rank info for distributed checkpointing rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - - def _save_tokenizer(self, output_dir: str, adapter_name: str = _default_adapter_name): + + def _save_tokenizer(self, + output_dir: str, + adapter_name: str = _default_adapter_name): """Save tokenizer.""" optimizer_config = self.optimizer_group.get(adapter_name) if optimizer_config and optimizer_config.template: @@ -1096,10 +1277,10 @@ def _save_tokenizer(self, output_dir: str, adapter_name: str = _default_adapter_ @remote_function(execute='first') def get_state_dict(self, **kwargs): """Get trainable state dict. - + Args: **kwargs: Additional arguments. - + Returns: State dict of trainable parameters. """ @@ -1107,24 +1288,24 @@ def get_state_dict(self, **kwargs): return self._get_trainable_parameters(adapter_name) _peft_patched = False - + @classmethod def _patch_peft_for_megatron(cls): """Patch PEFT's BaseTuner to handle Megatron's TransformerConfig. - + Megatron's TransformerConfig doesn't have a .get() method like HuggingFace configs. This patch handles the AttributeError that occurs when PEFT tries to check tie_word_embeddings. 
""" if cls._peft_patched: return - + from typing import List import torch.nn as nn from peft.tuners.tuners_utils import BaseTuner - + _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules - + def _get_tied_target_modules(self, model: nn.Module) -> List[str]: try: return _origin_get_tied_target_modules(self, model) @@ -1132,13 +1313,16 @@ def _get_tied_target_modules(self, model: nn.Module) -> List[str]: # Megatron's TransformerConfig doesn't have .get() method # Check share_embeddings_and_output_weights instead tied_target_modules = [] - if getattr(model, 'share_embeddings_and_output_weights', False): + if getattr(model, 'share_embeddings_and_output_weights', + False): for target_module in self.targeted_module_names: module_name = target_module.split('.')[-1] - if module_name in ['output_layer', 'embedding', 'word_embeddings']: + if module_name in [ + 'output_layer', 'embedding', 'word_embeddings' + ]: tied_target_modules.append(target_module) return tied_target_modules - + BaseTuner._get_tied_target_modules = _get_tied_target_modules cls._peft_patched = True @@ -1150,182 +1334,213 @@ def add_adapter_to_model( **kwargs, ): """Add LoRA adapter to model. - + Args: adapter_name: Name of the adapter. config_or_dir: LoRA config or path to saved adapter. **kwargs: Additional arguments. 
""" - from twinkle.megatron.utils import ( - prepare_lora_model, patch_deepcopy, get_target_modules, set_linear_is_expert - ) - + from twinkle.megatron.utils import (prepare_lora_model, patch_deepcopy, + get_target_modules, + set_linear_is_expert) + # Patch PEFT BaseTuner to handle Megatron's TransformerConfig # which doesn't have a .get() method like HuggingFace configs self._patch_peft_for_megatron() - + assert adapter_name, 'Use a non-empty adapter_name' - + model = self.strategy.unwrap_model(self.model) - + # Mark expert layers for MoE models set_linear_is_expert(model) - + if isinstance(config_or_dir, str): # Load from path config_or_dir = HubOperation.download_model(config_or_dir) from peft import PeftModel - model = PeftModel.from_pretrained( - model, config_or_dir, adapter_name=adapter_name, - is_trainable=kwargs.get('is_trainable', True) - ) + model = PeftModel.from_pretrained(model, + config_or_dir, + adapter_name=adapter_name, + is_trainable=kwargs.get( + 'is_trainable', True)) else: # Create from config from peft import LoraConfig, get_peft_model - + if not isinstance(config_or_dir, LoraConfig): # Convert dict to LoraConfig config_or_dir = LoraConfig(**config_or_dir) - + # Expand target_modules (e.g., 'all-linear' -> actual module names) if config_or_dir.target_modules: if isinstance(config_or_dir.target_modules, str): target_modules = [config_or_dir.target_modules] else: target_modules = list(config_or_dir.target_modules) - + expanded_modules = get_target_modules(model, target_modules) config_or_dir.target_modules = expanded_modules - + with patch_deepcopy(): - model = get_peft_model(model, config_or_dir, adapter_name=adapter_name) - + model = get_peft_model(model, + config_or_dir, + adapter_name=adapter_name) + # Update model reference if self._model_wrapped: if isinstance(self.model, MegatronDDP): self.model.module = model else: self.model = model - + # Add finish_grad_sync method for Megatron's finalize_model_grads compatibility # This is needed 
because Megatron's forward_backward_func calls finish_grad_sync # on model chunks, but PEFT models don't have this method by default if not hasattr(self.model, 'finish_grad_sync'): + def finish_grad_sync(): """Synchronize gradients across DP ranks for non-DDP models. - + This is a compatibility shim for Megatron's finalize_model_grads. - For PEFT/LoRA models, we manually all-reduce gradients. + For PEFT/LoRA models, we manually all-reduce only trainable (LoRA) gradients. + + Optimizations: + 1. Only process gradients of trainable parameters (LoRA weights) + 2. Skip if DP size is 1 (no synchronization needed) + 3. Use coalesced all-reduce for efficiency """ dp_world_size = mpu.get_data_parallel_world_size() - if dp_world_size > 1: - dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) - grads = [] - for param in self.model.parameters(): - if param.requires_grad and param.grad is not None: - grads.append(param.grad.data) - - if grads: - from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=dp_cp_group) - for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - grad.copy_(synced) - + if dp_world_size <= 1: + return # No sync needed for DP=1 + + dp_cp_group = mpu.get_data_parallel_group( + with_context_parallel=True) + grads = [] + + # Only collect gradients from trainable parameters (LoRA weights) + # This is much faster than iterating all parameters + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + grads.append(param.grad.data) + + if not grads: + return # No gradients to sync + + # Coalesced all-reduce for efficiency + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, + op=dist.ReduceOp.AVG, + group=dp_cp_group) + + # Copy back synchronized gradients + for grad, synced 
in zip( + grads, _unflatten_dense_tensors(coalesced, grads)): + grad.copy_(synced) + self.model.finish_grad_sync = finish_grad_sync - + # Create optimizer group for adapter self.optimizer_group[adapter_name] = MegatronOptimizerGroup() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config_or_dir - self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get( - 'gradient_accumulation_steps', 1 - ) - + self.optimizer_group[ + adapter_name].gradient_accumulation_steps = kwargs.get( + 'gradient_accumulation_steps', 1) + # Copy settings from default default_config = self.optimizer_group.get(_default_adapter_name) if default_config: if default_config.template: - self.optimizer_group[adapter_name].template = default_config.template + self.optimizer_group[ + adapter_name].template = default_config.template if default_config.processor: - self.optimizer_group[adapter_name].processor = default_config.processor + self.optimizer_group[ + adapter_name].processor = default_config.processor if default_config.loss_instance: - self.optimizer_group[adapter_name].loss_instance = default_config.loss_instance + self.optimizer_group[ + adapter_name].loss_instance = default_config.loss_instance @remote_function(dispatch='all') - def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs): + def set_template(self, template_cls: Union[Type[template.Template], str], + **kwargs): """Set template for input encoding. - + Args: template_cls: Template class or string name. **kwargs: Additional arguments. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(template_cls, str): if hasattr(template, template_cls): template_cls = getattr(template, template_cls) else: - template_cls = Plugin.load_plugin(template_cls, template.Template) + template_cls = Plugin.load_plugin(template_cls, + template.Template) optimizer_config.template = template_cls(self.model_id, **kwargs) @remote_function(dispatch='all') - def set_processor(self, processor_cls: Union[Type[InputProcessor], str], **kwargs): + def set_processor(self, processor_cls: Union[Type[InputProcessor], str], + **kwargs): """Set input processor. - + Args: processor_cls: Processor class or string name. **kwargs: Additional arguments. """ adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + if isinstance(processor_cls, str): if hasattr(twinkle.processor, processor_cls): processor_cls = getattr(twinkle.processor, processor_cls) else: - processor_cls = Plugin.load_plugin(processor_cls, InputProcessor) - optimizer_config.processor = processor_cls(device_mesh=self.device_mesh, **kwargs) + processor_cls = Plugin.load_plugin(processor_cls, + InputProcessor) + optimizer_config.processor = processor_cls( + device_mesh=self.device_mesh, **kwargs) @remote_function(execute='first') def get_train_configs(self, **kwargs): """Get training configuration summary. - + Args: **kwargs: Additional arguments. - + Returns: Configuration summary string. 
""" adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - + expr = f'Backend: Megatron-Core\n' expr += f'TP size: {self.strategy.tp_size}\n' expr += f'PP size: {self.strategy.pp_size}\n' expr += f'CP size: {self.strategy.cp_size}\n' expr += f'EP size: {self.strategy.ep_size}\n' expr += f'Sequence Parallel: {self.strategy.sequence_parallel}\n' - + if optimizer_config.adapter_config is not None: config = optimizer_config.adapter_config.__dict__ - config = {key: str(value) for key, value in config.items() if value is not None} + config = { + key: str(value) + for key, value in config.items() if value is not None + } expr += f'Adapter config:\n{json.dumps(config, indent=2, ensure_ascii=False)}\n' - + if optimizer_config.optimizer: expr += f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n' expr += f'Learning rate: {optimizer_config.optimizer.defaults.get("lr", "N/A")}\n' if optimizer_config.lr_scheduler: expr += f'LR scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n' expr += f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n' - + return expr - - def __repr__(self): - return ( - f"MegatronModel(model_id='{self.model_id}', " - f"tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, " - f"cp={self.strategy.cp_size}, ep={self.strategy.ep_size})" - ) + def __repr__(self): + return (f"MegatronModel(model_id='{self.model_id}', " + f'tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, ' + f'cp={self.strategy.cp_size}, ep={self.strategy.ep_size})') From 20f67a8c184d91377eec7b2c339a0cfcf755cfea Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 17:24:41 +0800 Subject: [PATCH 14/22] Add DP barrier for consistent synchronization in forward_backward For DP > 1, added a barrier at the end of forward_backward to ensure all DP replicas complete the same training step before moving to the next batch. 
This prevents P2P communication deadlocks in subsequent training iterations. --- src/twinkle/model/megatron.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index ed6062d4..71d86750 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -582,7 +582,7 @@ def forward_backward(self, if optimizer_config.template is not None: inputs = optimizer_config.template.batch_encode(inputs) - # Process inputs + # Process inputs (collate list to batched dict) processor = optimizer_config.processor if processor is not None: inputs = processor(inputs) @@ -832,9 +832,14 @@ def megatron_loss_func(labels_for_mask, cp_size, output_tensor): optimizer_config.cur_step += 1 - # Note: finalize_model_grads is called inside forward_backward_func - # which already handles gradient synchronization across DP replicas. - # No additional barrier is needed here - adding one would hurt performance. + # Critical: Synchronize all DP replicas before returning + # This ensures all DP replicas complete the same training step before + # moving to the next batch, preventing P2P communication deadlocks + dp_world_size = mpu.get_data_parallel_world_size() + if dp_world_size > 1: + # Use barrier on DP+CP group to synchronize all replicas + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + dist.barrier(group=dp_cp_group) if isinstance(loss, torch.Tensor): return loss.detach().cpu().float().numpy() From 6a589e9875448cf1ecc3129a128d294b59eb2cf8 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 17:30:55 +0800 Subject: [PATCH 15/22] Fix save condition to avoid saving at step 0 --- cookbook/megatron/lora.py | 2 +- cookbook/megatron/moe_lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index 8af6556e..870c607b 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ 
-163,7 +163,7 @@ def train(): model.step(adapter_name=adapter_name) model.zero_grad(adapter_name=adapter_name) model.lr_step(adapter_name=adapter_name) - if step % 100 == 0: + if step > 0 and step % 100 == 0: model.save('./output/megatron_lora', adapter_name=adapter_name) # Early stop for testing if args.max_steps and step >= args.max_steps * 16: diff --git a/cookbook/megatron/moe_lora.py b/cookbook/megatron/moe_lora.py index c8556cc8..272d7f41 100644 --- a/cookbook/megatron/moe_lora.py +++ b/cookbook/megatron/moe_lora.py @@ -200,7 +200,7 @@ def train(): model.step(adapter_name=adapter_name) model.zero_grad(adapter_name=adapter_name) model.lr_step(adapter_name=adapter_name) - if step % 100 == 0: + if step > 0 and step % 100 == 0: model.save('./output/megatron_moe_lora', adapter_name=adapter_name) # Early stop for testing if args.max_steps and step >= args.max_steps * 16: From 83c82dce52c04ef03c37a7ff855790ee3f28c5f0 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 17:36:54 +0800 Subject: [PATCH 16/22] fix merge --- cookbook/sft/streaming_dataset.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/cookbook/sft/streaming_dataset.py b/cookbook/sft/streaming_dataset.py index ced0c5bb..1eba29fd 100644 --- a/cookbook/sft/streaming_dataset.py +++ b/cookbook/sft/streaming_dataset.py @@ -25,17 +25,7 @@ mesh_dim_names=('dp', 'fsdp') ) -<<<<<<< HEAD -#device_mesh = DeviceMesh( -# device_type='cuda', -# mesh=np.array([0,1,2,3]), -# mesh_dim_names=('dp',) -#) - -twinkle.initialize(mode='ray', nproc_per_node=4, groups=device_group, global_device_mesh=device_mesh, lazy_collect=False) -======= twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh) ->>>>>>> origin/dev def create_dataset(): From 450c2043574eda99f85311a727097c654d6201a9 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 18:15:17 +0800 Subject: [PATCH 17/22] fix --- src/twinkle/infra/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git 
a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index e79726f8..4677f80d 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -543,6 +543,11 @@ def wrapper(self, *args, **kwargs) -> T1: import ray for _res in result: # raise when any worker raises StopIteration + resolved_results = ray.get(result) + for _res in resolved_results: + stop = _res[1] + if stop: + raise StopIteration() stop = ray.get(_res[1]) if stop: raise StopIteration() From 72ed37ef3052e88a6953f2bcd9553285ba1c9efd Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 18:29:37 +0800 Subject: [PATCH 18/22] fix --- src/twinkle/infra/__init__.py | 53 ++++++++++++++--------------------- src/twinkle/model/megatron.py | 4 +-- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 4677f80d..d157569e 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -525,38 +525,27 @@ def wrapper(self, *args, **kwargs) -> T1: from ._ray import RayHelper _workers_and_args = _dispatch_args(_get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) - - # Use sync execution for methods requiring NCCL synchronization - if sync: - result = RayHelper.execute_all_sync(func.__name__, _workers_and_args) - return _collect_func(collect, result) - else: - result = RayHelper.execute_all_async(func.__name__, _workers_and_args) - result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result) - lazy_collect = _lazy_collect - if hasattr(self, '_lazy_collect'): - lazy_collect = self._lazy_collect - result = result_func if lazy_collect else result_func() - if func.__name__ == '__iter__': - return self - if func.__name__ == '__next__': - import ray - for _res in result: - # raise when any worker raises StopIteration - resolved_results = ray.get(result) - for _res in resolved_results: - stop = _res[1] - if stop: - raise StopIteration() - stop = 
ray.get(_res[1]) - if stop: - raise StopIteration() - result = [_res[0] for _res in result] - result_func._futures = result - if hasattr(self, '_lazy_collect'): - lazy_collect = self._lazy_collect - result = result_func if lazy_collect else result_func() - return result + execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync + result = execute_method(func.__name__, _workers_and_args) + result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result) + lazy_collect = _lazy_collect + if func.__name__ == '__iter__': + return self + + if func.__name__ == '__next__': + import ray + for _res in result: + # raise when any worker raises StopIteration + stop = ray.get(_res[1]) + if stop: + raise StopIteration() + result = [_res[0] for _res in result] + result_func._futures = result + + if hasattr(self, '_lazy_collect'): + lazy_collect = self._lazy_collect + result = result_func if lazy_collect else result_func() + return result else: raise NotImplementedError(f'Unsupported mode {_mode}') diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py index 71d86750..be2926cb 100644 --- a/src/twinkle/model/megatron.py +++ b/src/twinkle/model/megatron.py @@ -490,7 +490,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], with torch.no_grad(): return self.forward(inputs=inputs, **kwargs) - @remote_function(collect='avg') + @remote_function(collect='mean') def calculate_loss(self, **kwargs): """Calculate loss from forward outputs. 
@@ -535,7 +535,7 @@ def backward(self, **kwargs): loss_value.backward() optimizer_config.cur_step += 1 - @remote_function(dispatch='all', collect='avg', sync=True) + @remote_function(dispatch='all', collect='mean', sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], From 67bb5d9ec44fb9355dc8471dbad8c8907136142e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 18:36:23 +0800 Subject: [PATCH 19/22] fix --- cookbook/megatron/lora.py | 4 ++-- src/twinkle/megatron/model/bridge.py | 11 +++-------- src/twinkle/megatron/model/initializer.py | 10 ++-------- src/twinkle/megatron/tuners/lora.py | 5 +---- 4 files changed, 8 insertions(+), 22 deletions(-) diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index 870c607b..cac2ba0c 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -182,8 +182,8 @@ def cleanup(): from megatron.core import parallel_state as mpu if mpu.is_initialized(): mpu.destroy_model_parallel() - except Exception: - pass + except Exception as e: + logger.warning(f"Error during cleanup: {e}") if dist.is_initialized(): dist.destroy_process_group() diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index ea1face1..cc16d356 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -423,7 +423,7 @@ def __init__(self, self.ep_group = mpu.get_expert_model_parallel_group() self.etp_rank = mpu.get_expert_tensor_parallel_rank() self.etp_group = mpu.get_expert_tensor_parallel_group() - except: + except (AttributeError, AssertionError): self.ep_rank = 0 self.ep_group = None self.etp_rank = 0 @@ -1927,13 +1927,8 @@ def finalize_model_grads_for_lora(model, moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=mg_config_dict.get('qk_layernorm', False), ) - except Exception: - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec - layer_spec = get_gpt_layer_local_spec( - 
num_experts=mg_config_dict.get('num_experts'), - moe_grouped_gemm=moe_grouped_gemm, - qk_layernorm=mg_config_dict.get('qk_layernorm', False), - ) + except (ImportError, AttributeError): + raise RuntimeError("TransformerEngine is not installed or not compatible with this version of Megatron-Core.") # Create model max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) diff --git a/src/twinkle/megatron/model/initializer.py b/src/twinkle/megatron/model/initializer.py index 76b7b976..215a4eb5 100644 --- a/src/twinkle/megatron/model/initializer.py +++ b/src/twinkle/megatron/model/initializer.py @@ -249,14 +249,8 @@ def _get_layer_spec(self, config: 'TransformerConfig'): qk_layernorm=qk_layernorm, multi_latent_attention=multi_latent_attention, ) - except Exception: - # Fallback to local spec without TE - return get_gpt_layer_local_spec( - num_experts=num_experts, - moe_grouped_gemm=moe_grouped_gemm, - qk_layernorm=qk_layernorm, - multi_latent_attention=multi_latent_attention, - ) + except (ImportError, AttributeError): + raise RuntimeError("TransformerEngine is not installed or not compatible with this version of Megatron-Core.") def load_from_hf( self, diff --git a/src/twinkle/megatron/tuners/lora.py b/src/twinkle/megatron/tuners/lora.py index 22075f7f..a9d29b73 100644 --- a/src/twinkle/megatron/tuners/lora.py +++ b/src/twinkle/megatron/tuners/lora.py @@ -639,7 +639,4 @@ def dispatch_megatron( # Register dispatch function with PEFT -try: - model.dispatch_megatron = dispatch_megatron -except Exception: - pass +model.dispatch_megatron = dispatch_megatron From 9954b7c9c6f05310f83090beb46e89ba1864f9bf Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 18:42:25 +0800 Subject: [PATCH 20/22] fix --- cookbook/megatron/moe_lora.py | 4 ++-- src/twinkle/megatron/model/bridge.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cookbook/megatron/moe_lora.py b/cookbook/megatron/moe_lora.py index 272d7f41..6f31e9e1 100644 --- 
a/cookbook/megatron/moe_lora.py +++ b/cookbook/megatron/moe_lora.py @@ -219,8 +219,8 @@ def cleanup(): from megatron.core import parallel_state as mpu if mpu.is_initialized(): mpu.destroy_model_parallel() - except Exception: - pass + except Exception as e: + logger.warning(f"Error during cleanup: {e}") if dist.is_initialized(): dist.destroy_process_group() diff --git a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py index cc16d356..4aab500e 100644 --- a/src/twinkle/megatron/model/bridge.py +++ b/src/twinkle/megatron/model/bridge.py @@ -905,8 +905,8 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): experts_module.weight1.data.copy_(fc1_stacked) else: # Handle TP split - tp_rank = self.tp_rank - tp_size = self.tp_size + tp_rank = self.etp_rank + tp_size = self.etp_size if tp_size > 1: # Split along last dim for weight1 chunk_size = fc1_stacked.shape[1] // tp_size @@ -921,8 +921,8 @@ def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): experts_module.weight2.data.copy_(fc2_stacked) else: # Handle TP split - tp_rank = self.tp_rank - tp_size = self.tp_size + tp_rank = self.etp_rank + tp_size = self.etp_size if tp_size > 1: # Split along first dim for weight2 chunk_size = fc2_stacked.shape[0] // tp_size From 9ebede2e23ea398912209924750dd170fe75a440 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 19:44:51 +0800 Subject: [PATCH 21/22] fix ep --- cookbook/megatron/moe_lora.py | 36 ++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/cookbook/megatron/moe_lora.py b/cookbook/megatron/moe_lora.py index 6f31e9e1..f4208c56 100644 --- a/cookbook/megatron/moe_lora.py +++ b/cookbook/megatron/moe_lora.py @@ -91,27 +91,45 @@ def train(): else: WORLD_SIZE = args.num_gpus - # For MoE with EP: Total parallelism = TP * PP * CP * EP * DP - # EP is placed between CP and DP in Megatron's order - DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE) + # 
DP calculation follows Megatron's logic: DP = world_size / (TP * PP * CP) + # EP is NOT included in DP calculation - it's handled separately by Megatron + # for MoE expert layers. Expert data parallel size is computed internally by Megatron. + DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) + # Validate that world size supports the parallelism config + # For MoE, EP must divide the data parallel replicas correctly if DP_SIZE < 1: raise ValueError( f'Not enough GPUs ({WORLD_SIZE}) for parallelism config: ' - f'TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}. ' - f'Required: {TP_SIZE * PP_SIZE * CP_SIZE * EP_SIZE}') + f'TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}. ' + f'Required at least: {TP_SIZE * PP_SIZE * CP_SIZE}') + + # EP should divide into world_size / (TP * PP) for proper expert parallelism + # This ensures expert_data_parallel_size = world_size / (ETP * EP * PP) is valid + expert_data_parallel_size = WORLD_SIZE // (TP_SIZE * EP_SIZE * PP_SIZE) + if expert_data_parallel_size < 1: + raise ValueError( + f'Not enough GPUs ({WORLD_SIZE}) for expert parallelism: ' + f'TP={TP_SIZE}, PP={PP_SIZE}, EP={EP_SIZE}. ' + f'Required at least: {TP_SIZE * EP_SIZE * PP_SIZE}') logger.info( f'Parallelism config: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}, DP={DP_SIZE}' ) + logger.info( + f'Expert data parallel size: {expert_data_parallel_size}' + ) # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost - # Shape: (PP, DP, EP, CP, TP) + # Note: EP is not a separate dimension in the device mesh because: + # 1. Megatron handles EP internally in initialize_model_parallel() + # 2. For non-expert layers, DP = world_size / (TP * PP * CP) + # 3. 
For expert layers, expert_data_parallel_size = world_size / (ETP * EP * PP) + # The device mesh is used by twinkle for data sharding, which follows DP_SIZE device_mesh = DeviceMesh( device_type='cuda', - mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, EP_SIZE, CP_SIZE, - TP_SIZE), - mesh_dim_names=('pp', 'dp', 'ep', 'cp', 'tp'), + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), + mesh_dim_names=('pp', 'dp', 'cp', 'tp'), ) # Device group name - used as remote_group in Ray mode From 2b7b4b8787de8395676eb6387d38ad344e5dc956 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 19:49:50 +0800 Subject: [PATCH 22/22] fix demo --- cookbook/megatron/lora.py | 12 ++++++------ cookbook/megatron/moe_lora.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py index cac2ba0c..74a39156 100644 --- a/cookbook/megatron/lora.py +++ b/cookbook/megatron/lora.py @@ -45,6 +45,7 @@ parser.add_argument('--model', type=str, default='ms://Qwen/Qwen2.5-7B-Instruct') +GAS = 16 # gradient accumulation steps args = parser.parse_args() # Set mode in environment before importing twinkle @@ -142,7 +143,7 @@ def train(): adapter_name = 'lora' model.add_adapter_to_model(adapter_name, lora_config, - gradient_accumulation_steps=16) + gradient_accumulation_steps=GAS) model.set_template('Qwen3Template', adapter_name=adapter_name) model.set_processor(InputProcessor, padding_side='right', @@ -157,22 +158,21 @@ def train(): for step, batch in enumerate(dataloader): output = model.forward_backward(inputs=batch, adapter_name=adapter_name) - if step % 16 == 0: + if step % GAS == 0: logger.info(f'Step {step // 16}, loss: {output}') model.clip_grad_norm(1.0, adapter_name=adapter_name) model.step(adapter_name=adapter_name) model.zero_grad(adapter_name=adapter_name) model.lr_step(adapter_name=adapter_name) - if step > 0 and step % 100 == 0: + if step > 0 and step % (100 * GAS) == 0: 
model.save('./output/megatron_lora', adapter_name=adapter_name) # Early stop for testing - if args.max_steps and step >= args.max_steps * 16: + if args.max_steps and step >= args.max_steps * GAS: logger.info(f'Reached max_steps ({args.max_steps}), stopping.') break - + model.save('./output/megatron_lora', adapter_name=adapter_name) logger.info('Training completed!') - def cleanup(): """Clean up distributed resources.""" import torch.distributed as dist diff --git a/cookbook/megatron/moe_lora.py b/cookbook/megatron/moe_lora.py index f4208c56..1cb72a7e 100644 --- a/cookbook/megatron/moe_lora.py +++ b/cookbook/megatron/moe_lora.py @@ -27,7 +27,7 @@ from twinkle.loss import MegatronCrossEntropyLoss from twinkle.model import MegatronModel from twinkle.processor import InputProcessor - +GAS = 16 # gradient accumulation steps # Parse arguments first to determine mode parser = argparse.ArgumentParser() parser.add_argument('--mode', @@ -197,7 +197,7 @@ def train(): adapter_name = 'lora' model.add_adapter_to_model(adapter_name, lora_config, - gradient_accumulation_steps=16) + gradient_accumulation_steps=GAS) model.set_template('Qwen3Template', adapter_name=adapter_name) model.set_processor(InputProcessor, padding_side='right', @@ -212,19 +212,19 @@ def train(): for step, batch in enumerate(dataloader): output = model.forward_backward(inputs=batch, adapter_name=adapter_name) - if step % 16 == 0: - logger.info(f'Step {step // 16}, loss: {output}') + if step % GAS == 0: + logger.info(f'Step {step // GAS}, loss: {output}') model.clip_grad_norm(1.0, adapter_name=adapter_name) model.step(adapter_name=adapter_name) model.zero_grad(adapter_name=adapter_name) model.lr_step(adapter_name=adapter_name) - if step > 0 and step % 100 == 0: + if step > 0 and step % (100 * GAS) == 0: model.save('./output/megatron_moe_lora', adapter_name=adapter_name) # Early stop for testing - if args.max_steps and step >= args.max_steps * 16: + if args.max_steps and step >= args.max_steps * GAS: 
logger.info(f'Reached max_steps ({args.max_steps}), stopping.') break - + model.save('./output/megatron_moe_lora', adapter_name=adapter_name) logger.info('Training completed!')