diff --git a/cookbook/legacy/grpo/lora.py b/cookbook/legacy/grpo/lora.py
index 92b13eeb..6746370e 100644
--- a/cookbook/legacy/grpo/lora.py
+++ b/cookbook/legacy/grpo/lora.py
@@ -1,29 +1,26 @@
 """
-GRPO Training Cookbook - Hybrid Mode with LoRA
+GRPO Training Cookbook - Standalone Mode with LoRA
 
 This cookbook demonstrates GRPO training using TransformersModel and VLLMSampler
-in hybrid mode (model and sampler colocated on same GPUs with IPC weight sync).
+in standalone mode (model and sampler on different GPUs with NCCL weight sync).
 
 Task: Countdown Game
 - Given numbers [a, b, c, d], find an equation using +, -, *, / that equals target
 - Rewards: format reward (<think>/<answer> tags) + accuracy reward (correct equation)
-Reference: swift/docs/source/BestPractices/GRPO.md
-
 Usage:
-    SWANLAB_API_KEY=xxx python lora.py
+    SWANLAB_API_KEY=xxx python cookbook/grpo/lora.py
 """
 
 import os
 import re
 import time
+import numpy as np
 from typing import List, Dict, Any, Tuple
 from dataclasses import dataclass, field
 from contextlib import contextmanager
 
-os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
-os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
 
 from peft import LoraConfig
 import torch
@@ -31,7 +28,7 @@ import twinkle
 from twinkle import DeviceMesh, DeviceGroup, Platform, get_device_placement, get_logger
 from twinkle import remote_class, remote_function
-from twinkle.data_format import Trajectory, Message
+from twinkle.data_format import Trajectory, Message, InputFeature
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.model import TransformersModel
@@ -39,7 +36,7 @@ from twinkle.sampler import VLLMSampler
 from twinkle.sampler.types import SamplingParams, SampleResponse
 from twinkle.rl import GRPOAdvantage
-from twinkle.weight_loader import IPCWeightLoader
+from twinkle.checkpoint_engine import CheckpointEngineManager
 from twinkle.template import Template
 
 from transformers import AutoTokenizer
@@ -55,35 +52,22 @@
 # ========== Configuration ==========
 MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-3B-Instruct')
 NUM_GPUS = int(os.environ.get('NUM_GPUS', 4))
-NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', NUM_GPUS // 2))
+SAMPLER_GPUS = NUM_GPUS - MODEL_GPUS
+NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4))
 MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024))
 LEARNING_RATE = float(os.environ.get('LR', 1e-5))
 GRPO_EPSILON = float(os.environ.get('GRPO_EPSILON', 0.2))
 GRPO_BETA = float(os.environ.get('GRPO_BETA', 0.0))
 MAX_STEPS = int(os.environ.get('MAX_STEPS', 2000))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
 TEMPERATURE = float(os.environ.get('TEMPERATURE', 1.0))
 WEIGHT_SYNC_INTERVAL = int(os.environ.get('WEIGHT_SYNC_INTERVAL', 1))
-SYSTEM_PROMPT = (
-    "You are a helpful assistant. You first thinks about the reasoning process "
-    "in the mind and then provides the user with the answer."
-)
 ADAPTER_NAME = 'default'
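For orientation, the standalone split divides the visible GPUs between a training group and a sampling group. The sketch below (hypothetical 4-GPU values, mirroring the `DeviceGroup`/`DeviceMesh` calls in `main()` further down) shows the resulting placement:

```python
# Sketch, assuming NUM_GPUS=4 and the default MODEL_GPUS = NUM_GPUS // 2 split:
# GPUs 0-1 host the training model, GPUs 2-3 host the vLLM sampler.
from twinkle import DeviceGroup, DeviceMesh

NUM_GPUS, MODEL_GPUS = 4, 2
device_groups = [
    DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)),
                device_type='GPU', gpus_per_worker=1),          # ranks [0, 1]
    DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)),
                device_type='GPU', gpus_per_worker=1),          # ranks [2, 3]
]
model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
sampler_mesh = DeviceMesh.from_sizes(world_size=NUM_GPUS - MODEL_GPUS,
                                     dp_size=NUM_GPUS - MODEL_GPUS)
```

Weights then flow from the model group to the sampler group over NCCL via `CheckpointEngineManager`, once every `WEIGHT_SYNC_INTERVAL` steps.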
-# ========== Profiling Context ==========
-@contextmanager
-def profiling_context(name: str):
-    """Context manager for timing and logging."""
-    start_time = time.perf_counter()
-    yield
-    duration = time.perf_counter() - start_time
-    if USE_SWANLAB:
-        swanlab.log({f'profiling/Time taken: GRPOTrainer.{name}': duration})
-
-
 # ========== Metrics ==========
 @dataclass
 class TrainingMetrics:
@@ -96,7 +80,7 @@ class TrainingMetrics:
     completion_lengths: List[int] = field(default_factory=list)
     loss: float = 0.0
     grad_norm: float = 0.0
-    
+
     def reset(self):
         self.generate_time = 0.0
         self.weight_sync_time = 0.0
@@ -106,7 +90,7 @@ def reset(self):
         self.completion_lengths = []
         self.loss = 0.0
         self.grad_norm = 0.0
-    
+
     def to_log_dict(self, step: int) -> Dict[str, float]:
         log_dict = {
             'step': step,
@@ -115,28 +99,19 @@
             'train/loss': self.loss,
             'train/grad_norm': self.grad_norm,
         }
-
         if self.rewards:
             log_dict['train/reward'] = sum(self.rewards) / len(self.rewards)
             log_dict['train/reward_std'] = torch.tensor(self.rewards).std().item() if len(self.rewards) > 1 else 0.0
-
         if self.format_rewards:
             log_dict['train/rewards/Format/mean'] = sum(self.format_rewards) / len(self.format_rewards)
-            log_dict['train/rewards/Format/std'] = torch.tensor(self.format_rewards).std().item() if len(self.format_rewards) > 1 else 0.0
-
         if self.accuracy_rewards:
             log_dict['train/rewards/CountdownORM/mean'] = sum(self.accuracy_rewards) / len(self.accuracy_rewards)
-            log_dict['train/rewards/CountdownORM/std'] = torch.tensor(self.accuracy_rewards).std().item() if len(self.accuracy_rewards) > 1 else 0.0
-
         if self.completion_lengths:
             log_dict['train/completions/mean_length'] = sum(self.completion_lengths) / len(self.completion_lengths)
-            log_dict['train/completions/min_length'] = min(self.completion_lengths)
-            log_dict['train/completions/max_length'] = max(self.completion_lengths)
-
         return log_dict
 
 
-# ========== Reward Functions ==========
+# ========== Rewards ==========
 def format_reward(completion: str) -> float:
     """Format reward: checks <think> and <answer> tags."""
     has_think = bool(re.search(r"<think>.*?</think>", completion, re.DOTALL))
     has_answer = bool(re.search(r"<answer>.*?</answer>", completion, re.DOTALL))
     return 1.0 if (has_think and has_answer) else 0.0
 
 
@@ -150,83 +125,27 @@
         match = re.search(r'<answer>(.*?)<\/answer>', completion)
         if match is None:
             return 0.0
-
         equation = match.group(1).strip()
         if '=' in equation:
             equation = equation.split('=')[0]
-
         used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
         if sorted(used_numbers) != sorted(nums):
             return 0.0
-
-        allowed_pattern = r'^[\d+\-*/().\s]+$'
-        if not re.match(allowed_pattern, equation):
+        if not re.match(r'^[\d+\-*/().\s]+$', equation):
             return 0.0
-
         result = eval(equation, {'__builtins__': None}, {})
         return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0
     except Exception:
         return 0.0
 
 
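As a quick sanity check of the two rewards above (illustrative only; the values follow directly from the regexes), a well-formed, correct completion earns both points, while a wrong result keeps only the format point:

```python
# Illustrative check of format_reward / countdown_accuracy_reward (not part of the patch).
completion = "<think>1 + 2 = 3, then 3 * 4 = 12</think> <answer>(1 + 2) * 4</answer>"

assert format_reward(completion) == 1.0  # both <think> and <answer> tags present
assert countdown_accuracy_reward(completion, target=12, nums=[1, 2, 4]) == 1.0
assert countdown_accuracy_reward(completion, target=10, nums=[1, 2, 4]) == 0.0  # wrong value
assert countdown_accuracy_reward(completion, target=12, nums=[1, 2, 3]) == 0.0  # wrong numbers
```

A total reward of 2.0 therefore means "well-formatted and correct", while 1.0 typically means "well-formatted but wrong".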
-def compute_rewards(trajectories: List[Trajectory]) -> Tuple[List[float], List[float], List[float]]:
-    """Compute format and accuracy rewards from trajectories.
-
-    Args:
-        trajectories: List of trajectories with 'messages' and 'user_data'.
-
-    Returns:
-        Tuple of (total_rewards, format_rewards, accuracy_rewards).
-    """
-    total_rewards, format_rewards, accuracy_rewards = [], [], []
-
-    for traj in trajectories:
-        messages = traj.get('messages', [])
-        completion = ""
-        for msg in reversed(messages):
-            if msg.get('role') == 'assistant':
-                completion = msg.get('content', '')
-                break
-
-        user_data = traj.get('user_data', [{}])
-        data = user_data[0] if isinstance(user_data, list) and user_data else {}
-        target = data.get('target', 0)
-        nums = data.get('nums', [])
-
-        fmt_reward = format_reward(completion)
-        acc_reward = countdown_accuracy_reward(completion, target, nums)
-
-        format_rewards.append(fmt_reward)
-        accuracy_rewards.append(acc_reward)
-        total_rewards.append(fmt_reward + acc_reward)
-
-    return total_rewards, format_rewards, accuracy_rewards
-
-
 # ========== Dataset ==========
 def create_countdown_dataset():
     """Create Countdown Game dataset."""
-    def countdown_processor(row: Dict[str, Any]) -> Dict[str, Any]:
-        nums = row.get('nums', [])
-        target = row.get('response', row.get('target', 0))
-
-        query = f"""Using the numbers {nums}, create an equation that equals {target}.
-You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
-Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags,
-for example <answer> (1 + 2) / 3 * 4 = 4 </answer>."""
-
-        return {
-            'messages': [
-                {'role': 'system', 'content': SYSTEM_PROMPT},
-                {'role': 'user', 'content': query},
-                {'role': 'assistant', 'content': ''},
-            ],
-            'user_data': [{'target': target, 'nums': nums}],
-        }
-
+    from twinkle.preprocessor import CountdownProcessor
     dataset = Dataset(DatasetMeta("ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000)))
     dataset.set_template("Template", model_id=MODEL_ID, max_length=8192)
-    dataset.map(countdown_processor)
+    dataset.map(CountdownProcessor())
     return dataset
 
 
@@ -236,51 +155,102 @@ def process_samples(
     sample_response: SampleResponse,
     tokenizer,
     num_generations: int,
-) -> List[Tuple[Trajectory, List[float], int]]:
-    """
-    Process sampled responses into (trajectory, old_logps, length) tuples.
-
-    Args:
-        prompts: List of original prompts (P prompts).
-        sample_response: Response containing sequences (P * num_generations sequences).
-        tokenizer: Tokenizer for decoding.
-        num_generations: Number of generations per prompt (G).
-
+    template: Template,
+) -> Tuple[List[Trajectory], List[InputFeature], List[List[float]], List[int]]:
+    """Process sampled responses.
+
+    Builds ``InputFeature`` directly by concatenating prompt token ids with
+    the sampler's raw response token ids, avoiding decode/re-encode drift.
+
     Returns:
-        List of (trajectory, old_logps, length) tuples.
-        The list has P * G entries, organized as:
-        [prompt0_gen0, prompt0_gen1, ..., prompt1_gen0, prompt1_gen1, ...]
+ (trajectories, input_features, old_logps_list, completion_lengths) """ - results = [] + trajectories: List[Trajectory] = [] + input_features: List[InputFeature] = [] + old_logps_list: List[List[float]] = [] + completion_lengths: List[int] = [] + sequences = sample_response.sequences - - # Sequences are organized as: for each prompt, num_generations sequences + prompt_ids_cache: Dict[int, List[int]] = {} + for i, prompt in enumerate(prompts): + if i not in prompt_ids_cache: + prompt_messages = [ + dict(msg) for msg in prompt.get('messages', []) + if not (msg.get('role') == 'assistant' + and not msg.get('content', '').strip()) + ] + encoded = tokenizer.apply_chat_template( + prompt_messages, tokenize=True, add_generation_prompt=True, + ) + if hasattr(encoded, 'tolist'): + encoded = encoded.tolist() + prompt_ids_cache[i] = list(encoded) + + prompt_ids = prompt_ids_cache[i] + for j in range(num_generations): seq_idx = i * num_generations + j if seq_idx >= len(sequences): - logger.warning(f"Expected {len(prompts) * num_generations} sequences, got {len(sequences)}") + logger.warning( + f"Expected {len(prompts) * num_generations} sequences, " + f"got {len(sequences)}" + ) break - + seq = sequences[seq_idx] - response_tokens = seq.tokens + response_tokens = list(seq.tokens) response_logprobs = seq.logprobs if seq.logprobs else [] response_text = tokenizer.decode(response_tokens, skip_special_tokens=True) - - # Build trajectory with response - messages = [] - for msg in prompt.get('messages', []): - # Skip empty assistant placeholder - if msg.get('role') == 'assistant' and not msg.get('content', '').strip(): - continue - messages.append(msg) + + # Trajectory (for reward computation only) + messages = [ + msg for msg in prompt.get('messages', []) + if not (msg.get('role') == 'assistant' + and not msg.get('content', '').strip()) + ] messages.append(Message(role='assistant', content=response_text)) - - # Copy user_data from prompt for reward computation - traj = Trajectory(messages=messages, user_data=prompt.get('user_data', [])) - results.append((traj, response_logprobs, len(response_tokens))) - - return results + trajectories.append(Trajectory( + messages=messages, + user_data=prompt.get('user_data', []), + )) + + # InputFeature (exact token alignment with sampler) + input_ids = prompt_ids + response_tokens + labels = [-100] * len(prompt_ids) + response_tokens + input_feature = InputFeature( + input_ids=np.array(input_ids), + labels=np.array(labels), + ) + input_feature = template._invoke_post_pipeline([input_feature]) + input_features.append(input_feature[0]) + + old_logps_list.append(response_logprobs) + completion_lengths.append(len(response_tokens)) + + return trajectories, input_features, old_logps_list, completion_lengths + + +def compute_rewards(trajectories: List[Trajectory]) -> Tuple[List[float], List[float], List[float]]: + """Compute format and accuracy rewards.""" + total_rewards, format_rewards, accuracy_rewards = [], [], [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = "" + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + user_data = traj.get('user_data', [{}]) + data = user_data[0] if isinstance(user_data, list) and user_data else {} + target = data.get('target', 0) + nums = data.get('nums', []) + fmt_reward = format_reward(completion) + acc_reward = countdown_accuracy_reward(completion, target, nums) + format_rewards.append(fmt_reward) + accuracy_rewards.append(acc_reward) + 
total_rewards.append(fmt_reward + acc_reward) + return total_rewards, format_rewards, accuracy_rewards def wait_result(result): @@ -292,185 +262,15 @@ def wait_result(result): return result -def log(msg): - """Print message with timestamp.""" - import datetime - ts = datetime.datetime.now().strftime("%H:%M:%S") - print(f"[{ts}] {msg}", flush=True) - - -def _collect_sample_responses(results): - """Custom collect function to merge multiple SampleResponse objects from DP workers. - - Args: - results: List of SampleResponse from each DP worker. - - Returns: - Merged SampleResponse with all sequences combined. - """ - if not results: - return SampleResponse(sequences=[]) - - if len(results) == 1: - return results[0] - - all_sequences = [] - for resp in results: - if resp is not None and hasattr(resp, 'sequences'): - all_sequences.extend(resp.sequences) - - return SampleResponse(sequences=all_sequences) - - -# ========== Hybrid Actor ========== -@remote_class() -class HybridModelSamplerActor: - """Hybrid actor that fuses training model and sampler in same process. - - This simulates the Hybrid mode where: - - Training model (TransformersModel) holds the real weights - - vLLM Sampler starts with dummy/random weights - - Weight sync happens via IPCWeightLoader (CUDA IPC + ZMQ) - """ - - def __init__( - self, - model_id: str, - device_mesh: DeviceMesh = None, - lora_config = None, - adapter_name: str = 'default', - learning_rate: float = 1e-5, - gradient_accumulation_steps: int = 8, - epsilon: float = 0.2, - beta: float = 0.0, - remote_group: str = None, - **kwargs - ): - import torch - rank = torch.cuda.current_device() if torch.cuda.is_available() else 0 - log(f"[Rank {rank}] Initializing HybridModelSamplerActor...") - - self.adapter_name = adapter_name - self.model_id = model_id - self.lora_config = lora_config # Store for weight sync - - # Initialize sampler with real model weights (not dummy) - # For LoRA training, vLLM loads base model weights, then we sync LoRA weights - self.sampler = VLLMSampler( - model_id=model_id, - engine_args={ - 'gpu_memory_utilization': 0.4, - 'max_model_len': 2048, - 'enforce_eager': True, - 'enable_sleep_mode': True, - # Enable LoRA in vLLM - 'enable_lora': True, - 'max_lora_rank': 64, - }, - ) - self.sampler.set_template(Template, model_id=model_id) - log(f"[Rank {rank}] VLLMSampler initialized with real base weights") - - # Initialize training model with real weights - self.model = TransformersModel(model_id=model_id, device_mesh=device_mesh) - log(f"[Rank {rank}] TransformersModel initialized with real weights") - - # Add LoRA adapter - if lora_config is not None: - self.model.add_adapter_to_model(adapter_name, lora_config, - gradient_accumulation_steps=gradient_accumulation_steps) - - # Set optimizer - self.model.set_optimizer('AdamW', lr=learning_rate, adapter_name=adapter_name) - - # Set lr scheduler - use LinearLR for simplicity - self.model.set_lr_scheduler('LinearLR', adapter_name=adapter_name) - - # Set loss - self.model.set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=epsilon, beta=beta) - - # Set processor - self.model.set_processor(InputProcessor, adapter_name=adapter_name) - - # Set template - self.model.set_template('Template', model_id=model_id, adapter_name=adapter_name) - - log(f"[Rank {rank}] Model configured with LoRA, optimizer, scheduler, loss") - - # Initialize weight loader for Hybrid mode (CUDA IPC) - self.weight_loader = IPCWeightLoader( - model=self.model, - sampler=self.sampler, - bucket_size_mb=512, - ) - log(f"[Rank {rank}] 
IPCWeightLoader initialized") - - @remote_function(dispatch='slice_dp', collect=_collect_sample_responses, lazy_collect=False) - def sample(self, batch, sampling_params: SamplingParams, num_samples: int = 1): - """Sample from the model.""" - return self.sampler.sample(batch, sampling_params, num_samples=num_samples) - - @remote_function() - def wake_up(self): - """Wake up the sampler.""" - self.sampler.wake_up() - - @remote_function() - def sleep(self): - """Put the sampler to sleep.""" - self.sampler.sleep() - - @remote_function() - def load_weights(self): - """Sync LoRA weights from model to sampler. - - Since vLLM loads base model weights during initialization (not using load_format='dummy'), - we only need to sync LoRA weights with base_sync_done=True. - """ - from dataclasses import asdict - peft_config = asdict(self.lora_config) if self.lora_config else None - # base_sync_done=True: vLLM has base model, only sync LoRA weights - self.weight_loader.load_weights( - adapter_name=self.adapter_name, - peft_config=peft_config, - base_sync_done=True, - ) - - @remote_function() - def forward_backward(self, inputs, trajectories=None, old_logps=None, **kwargs): - """Forward and backward pass.""" - return self.model.forward_backward( - inputs=inputs, - adapter_name=self.adapter_name, - trajectories=trajectories, - old_logps=old_logps, - **kwargs, - ) - - @remote_function() - def clip_grad_and_step(self): - """Clip gradients and step optimizer.""" - return self.model.clip_grad_and_step(adapter_name=self.adapter_name) - - @remote_function() - def get_train_configs(self): - """Get training configs.""" - return self.model.get_train_configs(adapter_name=self.adapter_name) - - @remote_function() - def save(self, path: str): - """Save model checkpoint.""" - self.model.save(path, adapter_name=self.adapter_name) - - # ========== Main ========== def main(): - # SwanLab setup (optional) if USE_SWANLAB: swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True) swanlab.init(project="ms-swift", config={ 'model_id': MODEL_ID, 'num_gpus': NUM_GPUS, + 'model_gpus': MODEL_GPUS, + 'sampler_gpus': SAMPLER_GPUS, 'num_generations': NUM_GENERATIONS, 'learning_rate': LEARNING_RATE, 'grpo_beta': GRPO_BETA, @@ -479,182 +279,161 @@ def main(): }) else: logger.info("SWANLAB_API_KEY not set, running without experiment tracking") - - # Hybrid mode: model and sampler on same GPUs + + # ── Device setup ────────────────────────────────────────────────── device_groups = [ - DeviceGroup(name='hybrid', ranks=list(range(NUM_GPUS)), device_type='GPU', gpus_per_worker=1), + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), + device_type='GPU', gpus_per_worker=1), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), + device_type='GPU', gpus_per_worker=1), ] - device_mesh = DeviceMesh.from_sizes(world_size=NUM_GPUS, dp_size=NUM_GPUS) - + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) logger.info(get_device_placement()) - + lora_config = LoraConfig( - target_modules="all-linear", - r=8, - lora_alpha=32, - lora_dropout=0.05, + target_modules="all-linear", r=8, lora_alpha=32, lora_dropout=0.05, ) - - # Create hybrid actor with all configurations - hybrid_actor = HybridModelSamplerActor( - model_id=MODEL_ID, - device_mesh=device_mesh, - lora_config=lora_config, - adapter_name=ADAPTER_NAME, - 
learning_rate=LEARNING_RATE, + + # ── Model (training) ────────────────────────────────────────────── + model = TransformersModel( + model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', + ) + model.add_adapter_to_model( + ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, - epsilon=GRPO_EPSILON, - beta=GRPO_BETA, - remote_group='hybrid', ) - - # Log training config - train_configs = wait_result(hybrid_actor.get_train_configs()) - logger.info(f"Training configs: {train_configs}") - - # Dataset + model.set_optimizer('AdamW', lr=LEARNING_RATE, adapter_name=ADAPTER_NAME) + model.set_lr_scheduler('LinearLR', adapter_name=ADAPTER_NAME) + model.set_loss('GRPOLoss', adapter_name=ADAPTER_NAME, + epsilon=GRPO_EPSILON, beta=GRPO_BETA) + model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME) + model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME) + + sampler = VLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'load_format': 'dummy', + 'gpu_memory_utilization': 0.7, + 'max_model_len': 2048, + 'enforce_eager': True, + 'enable_sleep_mode': False, + 'enable_lora': True, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template(Template, model_id=MODEL_ID) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) dataset = create_countdown_dataset() dataloader = DataLoader( - dataset=dataset, - batch_size=BATCH_SIZE, - device_mesh=device_mesh, - remote_group='hybrid', - num_workers=0, + dataset=dataset, batch_size=BATCH_SIZE, + device_mesh=model_mesh, remote_group='model', num_workers=0, ) - - # Tokenizer model_path = HubOperation.download_model(MODEL_ID) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - - # Advantage calculator advantage_fn = GRPOAdvantage() - - # Metrics metrics = TrainingMetrics() - + sampling_params = SamplingParams( - max_tokens=MAX_NEW_TOKENS, - temperature=TEMPERATURE, - top_p=0.95, + max_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_p=0.95, ) - - logger.info(f"Starting training for {MAX_STEPS} steps") - logger.info(f"Config: batch_size={BATCH_SIZE}, num_generations={NUM_GENERATIONS}, " - f"lr={LEARNING_RATE}, beta={GRPO_BETA}, epsilon={GRPO_EPSILON}") step = 0 - + for batch in dataloader: if step >= MAX_STEPS: break - + metrics.reset() - + if callable(batch): batch = batch() prompts = batch if isinstance(batch, list) else [batch] - - # ========== 1. Weight Sync (before generation) ========== + + # ========== 1. Weight Sync ========== if step % WEIGHT_SYNC_INTERVAL == 0: sync_start = time.perf_counter() - wait_result(hybrid_actor.load_weights()) + ckpt_manager.sync_weights(adapter_name=ADAPTER_NAME) metrics.weight_sync_time = time.perf_counter() - sync_start - - # ========== 2. Generate samples ========== - wait_result(hybrid_actor.wake_up()) - + + # ========== 2. Generate ========== gen_start = time.perf_counter() sample_response = wait_result( - hybrid_actor.sample(prompts, sampling_params, num_samples=NUM_GENERATIONS) + sampler.sample(prompts, sampling_params, num_samples=NUM_GENERATIONS) ) metrics.generate_time = time.perf_counter() - gen_start - - wait_result(hybrid_actor.sleep()) - + # ========== 3. 
Process samples ========== - samples = process_samples(prompts, sample_response, tokenizer, NUM_GENERATIONS) - - if not samples: + template = sampler._get_template(adapter_name=ADAPTER_NAME) + trajectories, input_features, old_logps_list, completion_lengths = \ + process_samples(prompts, sample_response, tokenizer, NUM_GENERATIONS, template) + + if not trajectories: logger.warning(f"Step {step}: No valid samples, skipping") step += 1 continue - - trajectories = [s[0] for s in samples] - old_logps_list = [s[1] for s in samples] - completion_lengths = [s[2] for s in samples] - + metrics.completion_lengths = completion_lengths - + # ========== 4. Compute rewards ========== total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories) - metrics.rewards = total_rewards metrics.format_rewards = format_rewards metrics.accuracy_rewards = accuracy_rewards - - # Debug: print sample completions and rewards for first step - if step == 0: - logger.info(f"=== Debug: Step {step} sample completions ===") - for i, traj in enumerate(trajectories[:3]): # Print first 3 - messages = traj.get('messages', []) - completion = "" - for msg in reversed(messages): - if msg.get('role') == 'assistant': - completion = msg.get('content', '')[:200] # First 200 chars - break - logger.info(f"Sample {i}: format_reward={format_rewards[i]}, acc_reward={accuracy_rewards[i]}") - logger.info(f" Completion: {completion}...") - logger.info(f"Total rewards stats: mean={sum(total_rewards)/len(total_rewards):.4f}, " - f"format_mean={sum(format_rewards)/len(format_rewards):.4f}, " - f"acc_mean={sum(accuracy_rewards)/len(accuracy_rewards):.4f}") - - # ========== 5. Compute advantages and add to trajectories ========== + + # ========== 5. Compute advantages ========== advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group') - - # Add advantages to trajectories (GRPOLoss extracts from trajectory['advantages']) - for i, traj in enumerate(trajectories): - traj['advantages'] = float(advantages[i]) - - # Check if all advantages are zero (frac_reward_zero_std indicator) - frac_zero_std = 1.0 if all(abs(adv) < 1e-8 for adv in advantages) else 0.0 - - # Skip if all advantages are zero + # Convert to list so dispatch='slice_dp' slices it in sync with inputs + advantages = advantages.tolist() + + frac_zero_std = 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0 if frac_zero_std == 1.0: logger.info(f"Step {step}: All advantages are zero, skipping training") step += 1 continue - + # ========== 6. Training step ========== - loss = wait_result(hybrid_actor.forward_backward( - inputs=trajectories, - trajectories=trajectories, + # Pass InputFeature list directly (exact token alignment with sampler). + # advantages and old_logps are lists, sliced in sync by dispatch. + loss = wait_result(model.forward_backward( + inputs=input_features, + adapter_name=ADAPTER_NAME, + advantages=advantages, old_logps=old_logps_list, )) - - grad_norm = wait_result(hybrid_actor.clip_grad_and_step()) - + + grad_norm = wait_result(model.clip_grad_and_step(adapter_name=ADAPTER_NAME)) metrics.loss = float(loss) if loss else 0.0 - metrics.grad_norm = float(grad_norm) if grad_norm else 0.0 - - # ========== 7. Log metrics ========== + if isinstance(grad_norm, list): + grad_norm = grad_norm[0] + metrics.grad_norm = float(grad_norm) if isinstance(grad_norm, (int, float)) else 0.0 + + from twinkle.utils.framework import Torch + import gc + gc.collect() + Torch.empty_cache() + + # ========== 7. 
Log ========== log_dict = metrics.to_log_dict(step) log_dict['train/frac_reward_zero_std'] = frac_zero_std - if USE_SWANLAB: swanlab.log(log_dict) - + logger.info( - f"Step {step}: loss={metrics.loss:.6f}, grad_norm={metrics.grad_norm:.4f}, " + f"Step {step}: loss={metrics.loss:.6f}, grad_norm={metrics.grad_norm:.7f}, " f"reward={log_dict.get('train/reward', 0):.4f}, " f"format={log_dict.get('train/rewards/Format/mean', 0):.2f}, " f"accuracy={log_dict.get('train/rewards/CountdownORM/mean', 0):.2f}, " f"completion_len={log_dict.get('train/completions/mean_length', 0):.1f}" ) - + step += 1 - + logger.info(f"Training completed. Total steps: {step}") - wait_result(hybrid_actor.save('grpo-countdown-checkpoint')) + wait_result(model.save('grpo-countdown-checkpoint', adapter_name=ADAPTER_NAME)) if USE_SWANLAB: swanlab.finish() diff --git a/cookbook/legacy/grpo/lora_npu.py b/cookbook/legacy/grpo/lora_npu.py index fe3f03ff..6a19f941 100644 --- a/cookbook/legacy/grpo/lora_npu.py +++ b/cookbook/legacy/grpo/lora_npu.py @@ -315,6 +315,7 @@ def create_dataset(): def train(): + raise NotImplementedError("Not implemented") nproc_per_node, actor_ranks, ref_ranks = parse_device_config() device_groups = create_device_groups(actor_ranks, ref_ranks) diff --git a/cookbook/legacy/grpo/megatron_lora.py b/cookbook/legacy/grpo/megatron_lora.py new file mode 100644 index 00000000..1bcafaa2 --- /dev/null +++ b/cookbook/legacy/grpo/megatron_lora.py @@ -0,0 +1,428 @@ +""" +GRPO Training Cookbook - MegatronModel with LoRA (Standalone Mode) + +Tests MegatronModel RL training with the same Countdown Game task as lora.py. + +Usage: + python cookbook/grpo/megatron_lora.py +""" + +import os +import re +import time +import numpy as np +from typing import List, Dict, Any, Tuple +from dataclasses import dataclass, field + +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +from peft import LoraConfig +import torch + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, Platform, get_device_placement, get_logger +from twinkle import remote_class, remote_function +from twinkle.data_format import Trajectory, Message, InputFeature +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.processor import InputProcessor +from twinkle.sampler import VLLMSampler +from twinkle.sampler.types import SamplingParams, SampleResponse +from twinkle.rl import GRPOAdvantage +import ray +from twinkle.template import Template + +from transformers import AutoTokenizer +from twinkle.hub import HubOperation + +logger = get_logger() + +# ========== Configuration ========== +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-3B-Instruct') +NUM_GPUS = 4 +MODEL_GPUS = 2 +SAMPLER_GPUS = 2 +NUM_GENERATIONS = 4 +MAX_NEW_TOKENS = 1024 +LEARNING_RATE = 1e-5 +GRPO_EPSILON = 0.2 +GRPO_BETA = 0.0 +MAX_STEPS = 20 +BATCH_SIZE = 2 +GRADIENT_ACCUMULATION_STEPS = 1 +TEMPERATURE = 1.0 +WEIGHT_SYNC_INTERVAL = 1 +ADAPTER_NAME = 'default' + + +# ========== Metrics ========== +@dataclass +class TrainingMetrics: + generate_time: float = 0.0 + weight_sync_time: float = 0.0 + rewards: List[float] = field(default_factory=list) + format_rewards: List[float] = field(default_factory=list) + accuracy_rewards: List[float] = field(default_factory=list) + completion_lengths: List[int] = field(default_factory=list) + loss: float = 0.0 + grad_norm: float = 0.0 + + def reset(self): + self.generate_time = 0.0 + self.weight_sync_time = 0.0 + self.rewards = [] + self.format_rewards = [] 
+        self.accuracy_rewards = []
+        self.completion_lengths = []
+        self.loss = 0.0
+        self.grad_norm = 0.0
+
+
+# ========== Rewards ==========
+def format_reward(completion: str) -> float:
+    has_think = bool(re.search(r"<think>.*?</think>", completion, re.DOTALL))
+    has_answer = bool(re.search(r"<answer>.*?</answer>", completion, re.DOTALL))
+    return 1.0 if (has_think and has_answer) else 0.0
+
+
+def countdown_accuracy_reward(completion: str, target: int, nums: List[int]) -> float:
+    try:
+        match = re.search(r'<answer>(.*?)<\/answer>', completion)
+        if match is None:
+            return 0.0
+        equation = match.group(1).strip()
+        if '=' in equation:
+            equation = equation.split('=')[0]
+        used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
+        if sorted(used_numbers) != sorted(nums):
+            return 0.0
+        if not re.match(r'^[\d+\-*/().\s]+$', equation):
+            return 0.0
+        result = eval(equation, {'__builtins__': None}, {})
+        return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0
+    except Exception:
+        return 0.0
+
+
+# ========== Dataset ==========
+def create_countdown_dataset():
+    from twinkle.preprocessor import CountdownProcessor
+    dataset = Dataset(DatasetMeta("ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000)))
+    dataset.set_template("Template", model_id=MODEL_ID, max_length=8192)
+    dataset.map(CountdownProcessor())
+    return dataset
+
+
+# ========== Sample Processing ==========
+def process_samples(
+    prompts: List[Trajectory],
+    sample_response: SampleResponse,
+    tokenizer,
+    num_generations: int,
+    template: Template,
+) -> Tuple[List[Trajectory], List[InputFeature], List[List[float]], List[int]]:
+    """Process sampled responses — same logic as lora.py."""
+    trajectories: List[Trajectory] = []
+    input_features: List[InputFeature] = []
+    old_logps_list: List[List[float]] = []
+    completion_lengths: List[int] = []
+
+    sequences = sample_response.sequences
+    prompt_ids_cache: Dict[int, List[int]] = {}
+
+    for i, prompt in enumerate(prompts):
+        if i not in prompt_ids_cache:
+            prompt_messages = [
+                dict(msg) for msg in prompt.get('messages', [])
+                if not (msg.get('role') == 'assistant'
+                        and not msg.get('content', '').strip())
+            ]
+            encoded = tokenizer.apply_chat_template(
+                prompt_messages, tokenize=True, add_generation_prompt=True,
+            )
+            if hasattr(encoded, 'tolist'):
+                encoded = encoded.tolist()
+            prompt_ids_cache[i] = list(encoded)
+
+        prompt_ids = prompt_ids_cache[i]
+
+        for j in range(num_generations):
+            seq_idx = i * num_generations + j
+            if seq_idx >= len(sequences):
+                break
+
+            seq = sequences[seq_idx]
+            response_tokens = list(seq.tokens)
+            response_logprobs = seq.logprobs if seq.logprobs else []
+            response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
+
+            messages = [
+                msg for msg in prompt.get('messages', [])
+                if not (msg.get('role') == 'assistant'
+                        and not msg.get('content', '').strip())
+            ]
+            messages.append(Message(role='assistant', content=response_text))
+            trajectories.append(Trajectory(
+                messages=messages,
+                user_data=prompt.get('user_data', []),
+            ))
+
+            input_ids = prompt_ids + response_tokens
+            labels = [-100] * len(prompt_ids) + response_tokens
+            input_feature = InputFeature(
+                input_ids=np.array(input_ids),
+                labels=np.array(labels),
+            )
+            input_feature = template._invoke_post_pipeline([input_feature])
+            input_features.append(input_feature[0])
+
+            old_logps_list.append(response_logprobs)
+            completion_lengths.append(len(response_tokens))
+
+    return trajectories, input_features, old_logps_list, completion_lengths
+
+
+def compute_rewards(trajectories: List[Trajectory]) -> 
Tuple[List[float], List[float], List[float]]: + total_rewards, format_rewards, accuracy_rewards = [], [], [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = "" + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + user_data = traj.get('user_data', [{}]) + data = user_data[0] if isinstance(user_data, list) and user_data else {} + target = data.get('target', 0) + nums = data.get('nums', []) + fmt_reward = format_reward(completion) + acc_reward = countdown_accuracy_reward(completion, target, nums) + format_rewards.append(fmt_reward) + accuracy_rewards.append(acc_reward) + total_rewards.append(fmt_reward + acc_reward) + return total_rewards, format_rewards, accuracy_rewards + + +def wait_result(result): + if hasattr(result, '_is_lazy_collect') and result._is_lazy_collect: + return result() + if callable(result) and hasattr(result, '_get_result'): + return result() + return result + + +class SimpleWeightSync: + """Sync weights from MegatronModel to VLLMSampler via Ray object store. + + Avoids ray.util.collective NCCL, which conflicts with Megatron's + torch.distributed NCCL (Megatron's initialize_model_parallel creates + NCCL communicators that are incompatible with cupy's NCCL bindings + used by ray.util.collective). + """ + + def __init__(self, model, sampler, adapter_name: str = ''): + self.model = model + self.sampler = sampler + self.adapter_name = adapter_name + self.base_sync_done = False + + def sync_weights(self, adapter_name: str = ''): + """Sync model weights to sampler via Ray object store.""" + adapter_name = adapter_name or self.adapter_name + + if not self.base_sync_done: + # First sync: all base weights + weights_dict = wait_result( + self.model.export_weights_dict(adapter_name=adapter_name) + ) + peft_config = None + else: + # Subsequent syncs: LoRA weights only + weights_dict = wait_result( + self.model.export_weights_dict(adapter_name=adapter_name, lora_only=True) + ) + peft_config = wait_result( + self.model.get_peft_config_dict(adapter_name=adapter_name) + ) + + # Load into sampler + wait_result(self.sampler.import_weights_dict( + weights=weights_dict, + peft_config=peft_config, + base_sync_done=self.base_sync_done, + )) + + # TODO: remove this after lora sync is implemented + # if not self.base_sync_done: + # self.base_sync_done = True + + +# ========== Main ========== +def main(): + # ── Device setup ────────────────────────────────────────────────── + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), + device_type='GPU', gpus_per_worker=1), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), + device_type='GPU', gpus_per_worker=1), + ] + # MegatronModel: DP=2, TP=1, PP=1 for 2 GPUs + model_mesh = DeviceMesh.from_sizes( + dp_size=MODEL_GPUS, tp_size=1, pp_size=1, + ) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) + logger.info(get_device_placement()) + + lora_config = LoraConfig( + target_modules="all-linear", r=8, lora_alpha=32, lora_dropout=0.05, + ) + + # ── MegatronModel (training) ────────────────────────────────────── + model = MegatronModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + mixed_precision='bf16', + recompute_granularity='selective', + ) + model.add_adapter_to_model( + ADAPTER_NAME, lora_config, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + # 
MegatronModel uses Megatron's distributed optimizer and scheduler + model.set_optimizer('default', lr=LEARNING_RATE, adapter_name=ADAPTER_NAME) + model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE, + adapter_name=ADAPTER_NAME) + model.set_loss('GRPOLoss', adapter_name=ADAPTER_NAME, + epsilon=GRPO_EPSILON, beta=GRPO_BETA) + model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME) + model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME) + + sampler = VLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'load_format': 'dummy', + 'gpu_memory_utilization': 0.3, + 'max_model_len': 2048, + 'enforce_eager': True, + 'enable_sleep_mode': False, + 'enable_lora': False, # sync lora todo + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template(Template, model_id=MODEL_ID) + + # Use SimpleWeightSync instead of CheckpointEngineManager to avoid + # NCCL conflict between Megatron's torch.distributed and cupy NCCL. + weight_sync = SimpleWeightSync(model, sampler, adapter_name=ADAPTER_NAME) + dataset = create_countdown_dataset() + dataloader = DataLoader( + dataset=dataset, batch_size=BATCH_SIZE, + device_mesh=model_mesh, remote_group='model', num_workers=0, + ) + model_path = HubOperation.download_model(MODEL_ID) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + advantage_fn = GRPOAdvantage() + metrics = TrainingMetrics() + + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_p=0.95, + ) + step = 0 + + for batch in dataloader: + if step >= MAX_STEPS: + break + + metrics.reset() + + if callable(batch): + batch = batch() + prompts = batch if isinstance(batch, list) else [batch] + + # ========== 1. Weight Sync ========== + if step % WEIGHT_SYNC_INTERVAL == 0: + sync_start = time.perf_counter() + weight_sync.sync_weights(adapter_name=ADAPTER_NAME) + metrics.weight_sync_time = time.perf_counter() - sync_start + + # ========== 2. Generate ========== + gen_start = time.perf_counter() + sample_response = wait_result( + sampler.sample(prompts, sampling_params, num_samples=NUM_GENERATIONS) + ) + metrics.generate_time = time.perf_counter() - gen_start + + # ========== 3. Process samples ========== + template = sampler._get_template(adapter_name=ADAPTER_NAME) + trajectories, input_features, old_logps_list, completion_lengths = \ + process_samples(prompts, sample_response, tokenizer, NUM_GENERATIONS, template) + + if not trajectories: + logger.warning(f"Step {step}: No valid samples, skipping") + step += 1 + continue + + metrics.completion_lengths = completion_lengths + + # ========== 4. Compute rewards ========== + total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories) + metrics.rewards = total_rewards + metrics.format_rewards = format_rewards + metrics.accuracy_rewards = accuracy_rewards + + # ========== 5. Compute advantages ========== + advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group') + advantages = advantages.tolist() + + frac_zero_std = 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0 + if frac_zero_std == 1.0: + logger.info(f"Step {step}: All advantages are zero, skipping training") + step += 1 + continue + + # ========== 6. 
Training step ========== + # MegatronModel.forward_backward returns float loss directly + loss = wait_result(model.forward_backward( + inputs=input_features, + adapter_name=ADAPTER_NAME, + advantages=advantages, + old_logps=old_logps_list, + )) + + # MegatronModel: step/zero_grad/lr_step separately + # step() stores grad_norm internally + wait_result(model.step(adapter_name=ADAPTER_NAME)) + wait_result(model.zero_grad(adapter_name=ADAPTER_NAME)) + wait_result(model.lr_step(adapter_name=ADAPTER_NAME)) + + metrics.loss = float(loss) if loss is not None else 0.0 + # grad_norm is not directly returned; it's stored in optimizer_config + # For now, log loss only; grad_norm can be retrieved if needed + metrics.grad_norm = 0.0 + + import gc + from twinkle.utils.framework import Torch + gc.collect() + Torch.empty_cache() + + # ========== 7. Log ========== + logger.info( + f"Step {step}: loss={metrics.loss:.6f}, grad_norm={metrics.grad_norm:.7f}, " + f"reward={sum(metrics.rewards) / max(len(metrics.rewards), 1):.4f}, " + f"format={sum(metrics.format_rewards) / max(len(metrics.format_rewards), 1):.2f}, " + f"accuracy={sum(metrics.accuracy_rewards) / max(len(metrics.accuracy_rewards), 1):.2f}, " + f"completion_len={sum(metrics.completion_lengths) / max(len(metrics.completion_lengths), 1):.1f}" + ) + + step += 1 + + logger.info(f"Training completed. Total steps: {step}") + + +if __name__ == '__main__': + main() diff --git a/src/twinkle/checkpoint_engine/README.md b/src/twinkle/checkpoint_engine/README.md new file mode 100644 index 00000000..e69de29b diff --git a/src/twinkle/checkpoint_engine/__init__.py b/src/twinkle/checkpoint_engine/__init__.py new file mode 100644 index 00000000..3c9f50c8 --- /dev/null +++ b/src/twinkle/checkpoint_engine/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Checkpoint Engine for weight synchronization between trainer and rollout. + +Provides NCCL/HCCL-based weight broadcast from training model workers to +inference sampler workers in STANDALONE (disaggregated) deployment mode. + +Reference: https://github.com/volcengine/verl/tree/main/verl/checkpoint_engine + +Usage: + >>> from twinkle.checkpoint_engine import CheckpointEngineManager + >>> + >>> manager = CheckpointEngineManager(model=model, sampler=sampler) + >>> manager.sync_weights() # blocking call +""" + +from .base import ( + CheckpointEngine, + CheckpointEngineRegistry, + ColocatedCheckpointEngine, + TensorMeta, +) +from .manager import CheckpointEngineManager +from .mixin import CheckpointEngineMixin + +# Import backend implementations to register them +from .nccl_checkpoint_engine import NCCLCheckpointEngine +from .hccl_checkpoint_engine import HCCLCheckpointEngine + +__all__ = [ + "CheckpointEngine", + "CheckpointEngineRegistry", + "CheckpointEngineMixin", + "ColocatedCheckpointEngine", + "CheckpointEngineManager", + "NCCLCheckpointEngine", + "HCCLCheckpointEngine", + "TensorMeta", +] diff --git a/src/twinkle/checkpoint_engine/base.py b/src/twinkle/checkpoint_engine/base.py new file mode 100644 index 00000000..2a3addfa --- /dev/null +++ b/src/twinkle/checkpoint_engine/base.py @@ -0,0 +1,254 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py +"""Base classes for checkpoint engine. + +CheckpointEngine is an abstraction layer to synchronize weights between +trainer and rollout. 
It provides unified APIs: +- send_weights: Get named tensors from generator and send them in streaming manner. +- receive_weights: Return a tensor generator that yields named tensors in streaming manner. +""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, Generator, TypedDict + +import torch + +logger = logging.getLogger(__name__) + + +class TensorMeta(TypedDict): + """Metadata for a tensor in the weight bucket.""" + name: str + shape: torch.Size + dtype: torch.dtype + offset: int + + +class CheckpointEngineRegistry: + """Registry for checkpoint engine backends.""" + + _registry: dict[str, type["CheckpointEngine"]] = {} + + @classmethod + def register(cls, backend: str): + """Register a checkpoint engine backend. + + Args: + backend: The backend name (e.g., 'naive', 'nccl', 'hccl'). + """ + def wrapper(engine_cls: type["CheckpointEngine"]): + cls._registry[backend] = engine_cls + return engine_cls + return wrapper + + @classmethod + def get(cls, backend: str) -> type["CheckpointEngine"]: + """Get the checkpoint engine class by backend name. + + Args: + backend: The backend name. + + Returns: + The checkpoint engine class. + """ + if backend not in cls._registry: + raise ValueError(f"Checkpoint engine '{backend}' not registered. " + f"Available backends: {list(cls._registry.keys())}") + return cls._registry[backend] + + @classmethod + def new(cls, backend: str, *args, **kwargs) -> "CheckpointEngine": + """Create a new checkpoint engine instance. + + Args: + backend: The backend name. + *args: Positional arguments for the engine constructor. + **kwargs: Keyword arguments for the engine constructor. + + Returns: + A new checkpoint engine instance. + """ + return cls.get(backend)(*args, **kwargs) + + +class CheckpointEngine(ABC): + """Abstract base class for checkpoint engines. + + A checkpoint engine handles weight synchronization between trainer and rollout + processes. The typical workflow is: + + In trainer process (rank 0): + >>> engine = CheckpointEngineRegistry.new('nccl', bucket_size=512<<20) + >>> engine.is_master = True # set before prepare() + >>> engine.prepare() + >>> engine.init_process_group(rank=0, world_size=5, master_metadata=metadata) + >>> await engine.send_weights(weight_generator()) + >>> engine.finalize() + + In rollout process: + >>> engine = CheckpointEngineRegistry.new('nccl', bucket_size=512<<20) + >>> engine.prepare() + >>> engine.init_process_group(rank=1, world_size=5, master_metadata=metadata) + >>> async for name, tensor in engine.receive_weights(): + ... weights.append((name, tensor)) + >>> engine.finalize() + """ + + @abstractmethod + def prepare(self) -> dict[str, Any]: + """Prepare the checkpoint engine before weight synchronization. + + This method should: + 1. Allocate weight transfer buffers. + 2. Setup communication channels (e.g., ZMQ sockets). + 3. Return metadata needed for topology building. + + Returns: + A dictionary containing metadata (e.g., master IP and port). + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_topology( + cls, + trainer_world_size: int, + rollout_world_size: int, + metadata: list[dict], + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """Build communication topology between trainer and rollout workers. + + This method determines the rank assignment for each worker in the + temporary NCCL/HCCL process group used for weight synchronization. + + Args: + trainer_world_size: Number of trainer workers. 
+ rollout_world_size: Number of rollout workers. + metadata: List of metadata from all workers' prepare() calls. + + Returns: + A tuple of (trainer_kwargs, rollout_kwargs), where each dict + contains lists of arguments to pass to init_process_group(). + Keys typically include: 'rank', 'world_size', 'master_metadata'. + """ + raise NotImplementedError + + @abstractmethod + def init_process_group(self, **kwargs): + """Initialize the process group for weight synchronization. + + Args: + **kwargs: Arguments from build_topology(), typically including: + - rank: The rank of this worker in the sync group. + - world_size: Total number of workers in the sync group. + - master_metadata: Metadata from the master (trainer rank 0). + """ + raise NotImplementedError + + @abstractmethod + def finalize(self): + """Finalize the checkpoint engine after weight synchronization. + + This method should: + 1. Free weight transfer buffers. + 2. Destroy the temporary process group (if rebuild_group=True). + 3. Clean up communication channels. + """ + raise NotImplementedError + + @abstractmethod + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send model weights to rollout workers. + + This method streams weights in buckets to avoid memory issues with + large models. Only trainer rank 0 actually sends weights; other + trainer ranks consume the generator without sending. + + Args: + weights: A generator yielding (name, tensor) pairs. + """ + raise NotImplementedError + + @abstractmethod + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive model weights from trainer. + + This method receives weights in buckets and yields them as they + become available, enabling streaming weight loading. + + Yields: + Tuples of (name, tensor) for each weight. + """ + raise NotImplementedError + + +@CheckpointEngineRegistry.register("naive") +class ColocatedCheckpointEngine(CheckpointEngine): + """Checkpoint engine for colocated trainer and rollout on same GPU. + + This is a simple pass-through engine that directly shares the weight + generator between trainer and rollout without network transfer. + It's used for Hybrid mode where trainer and rollout share the same GPU. + + Usage: + >>> engine = ColocatedCheckpointEngine(bucket_size=512<<20) + >>> engine.send_weights(model.get_hf_state_dict()) + >>> for name, tensor in engine.receive_weights(): + ... weights.append((name, tensor)) + """ + + def __init__(self, bucket_size: int, is_master: bool = False, **kwargs) -> None: + """Initialize the colocated checkpoint engine. + + Args: + bucket_size: Size of the transfer bucket in bytes (not used but kept for API compatibility). + is_master: Whether this is the master process (not used). 
+ """ + self.bucket_size = bucket_size + self.is_master = is_master + self.weights = None + + def prepare(self) -> dict[str, Any]: + """No preparation needed for colocated mode.""" + return {} + + @classmethod + def build_topology( + cls, + trainer_world_size: int, + rollout_world_size: int, + metadata: list[dict], + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """No topology building needed for colocated mode.""" + return {}, {} + + def init_process_group(self, **kwargs): + """No process group needed for colocated mode.""" + pass + + def finalize(self): + """No cleanup needed for colocated mode.""" + self.weights = None + + def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Store the weights generator for later retrieval. + + Note: This is a synchronous method since no network transfer is needed. + + Args: + weights: A generator yielding (name, tensor) pairs. + """ + self.weights = weights + + def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]: + """Retrieve the stored weights generator. + + Note: This is a synchronous method since no network transfer is needed. + + Yields: + Tuples of (name, tensor) from the stored generator. + """ + if self.weights is not None: + yield from self.weights + self.weights = None diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py new file mode 100644 index 00000000..e1078e2e --- /dev/null +++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py @@ -0,0 +1,438 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/hccl_checkpoint_engine.py +"""HCCL-based checkpoint engine for Ascend NPU. + +This engine uses HCCL broadcast for efficient NPU-to-NPU weight transfer +across different processes/nodes. It supports: +- Double buffering for pipelined transfer +- ZMQ for metadata, HCCL for weight data +- Streaming weight transfer to avoid OOM +""" + +import asyncio +import logging +import time +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Generator + +import torch +import zmq + +from twinkle.utils.network import ( + find_node_ip, + find_free_port, + is_valid_ipv6_address, + stateless_init_process_group, +) +from .base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta + +logger = logging.getLogger(__name__) + + +@dataclass +class MasterMetadata: + """Metadata from the master for process group initialization.""" + zmq_ip: str + zmq_port: int + dist_ip: str + dist_port: int + + +class BroadcastOperation: + """Async broadcast operation with HCCL in separate thread. + + Args: + rank: The rank of the current process. + process_group: The HCCL process group. + bucket: The tensor buffer to broadcast. + metadata: The metadata of tensors in the bucket. + socket: The ZMQ socket for metadata communication. + topic: The ZMQ topic for pub/sub. 
+ """ + + def __init__( + self, + rank: int, + process_group, + bucket: torch.Tensor, + metadata: dict[str, TensorMeta], + socket: zmq.Socket, + topic: str, + ) -> None: + self.rank = rank + self.pyhccl = process_group + self.bucket = bucket + self.metadata = metadata + self.socket = socket + self.topic = topic + + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + + def _run(self): + """Execute the broadcast operation in a thread.""" + # Broadcast tensor metadata via ZMQ PUB/SUB + if self.rank == 0: + self.socket.send_string(self.topic, flags=zmq.SNDMORE) + self.socket.send_pyobj(self.metadata) + else: + self.socket.recv_string() + self.metadata = self.socket.recv_pyobj() + + # Broadcast tensor data via HCCL + self.pyhccl.broadcast(self.bucket, src=0) + + async def wait_for_complete(self) -> dict[str, TensorMeta]: + """Wait for the broadcast operation to complete. + + Returns: + The bucket metadata after broadcast. + """ + await self._task + return self.metadata + + +@CheckpointEngineRegistry.register("hccl") +class HCCLCheckpointEngine(CheckpointEngine): + """HCCL checkpoint engine for Ascend NPU. + + Same lifecycle and semantics as NCCLCheckpointEngine but uses HCCL + instead of NCCL and stateless_init_process_group instead of + ray.util.collective. + + Args: + bucket_size: Bucket size in bytes for weight transfer. + group_name: Name of the process group. + rebuild_group: Whether to rebuild the group each sync. + rollout_dtype: Target dtype for weights. + """ + + def __init__( + self, + bucket_size: int, + group_name: str = "twinkle_ckpt", + rebuild_group: bool = True, + rollout_dtype: torch.dtype = torch.bfloat16, + **kwargs, + ) -> None: + self.bucket_size = bucket_size + self.group_name = group_name + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + self.pyhccl = None + + # Get current NPU device + try: + self.device = torch.npu.current_device() + except Exception: + self.device = 0 + + # Set by Manager before prepare() via attribute assignment + self.is_master = False + self.topic = "bucket_metadata" + + # Will be set during prepare / init_process_group + self.rank = None + self.world_size = None + self.send_buf = None + self.recv_buf = None + self.socket = None + + # Track whether resources are ready for reuse + self._prepared = False + self._group_initialized = False + + # ── ZMQ helpers ────────────────────────────────────────────────────── + + def _start_zmq_server(self): + """Start ZMQ PUB server for metadata broadcast (master only).""" + self.ip = find_node_ip() + self.zmq_port = find_free_port() + self.dist_port = find_free_port() + + context = zmq.Context() + self.socket = context.socket(zmq.PUB) + if is_valid_ipv6_address(self.ip): + address = f"tcp://[{self.ip}]:{self.zmq_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{self.ip}:{self.zmq_port}" + + self.socket.bind(address) + logger.debug(f"ZMQ PUB server started at {address}") + + def _connect_zmq_client(self, metadata: MasterMetadata): + """Connect to the ZMQ PUB server as a subscriber (receiver only).""" + context = zmq.Context() + self.socket = context.socket(zmq.SUB) + if is_valid_ipv6_address(metadata.zmq_ip): + address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" + + self.socket.connect(address) + self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) + logger.debug(f"ZMQ SUB client connected to {address}") + + # ── 
Core lifecycle ─────────────────────────────────────────────────── + + def prepare(self) -> MasterMetadata | None: + """Allocate double buffers and start ZMQ server (master only). + + Idempotent: skips if already prepared. + + Returns: + MasterMetadata with ZMQ/dist IP/port if master, else None. + """ + if self._prepared: + if self.is_master: + return MasterMetadata( + zmq_ip=self.ip, zmq_port=self.zmq_port, + dist_ip=self.ip, dist_port=self.dist_port, + ) + return None + + self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") + self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") + + if self.is_master: + self._start_zmq_server() + self._prepared = True + return MasterMetadata( + zmq_ip=self.ip, zmq_port=self.zmq_port, + dist_ip=self.ip, dist_port=self.dist_port, + ) + self._prepared = True + return None + + def finalize(self): + """Clean up resources after a sync. + + When ``rebuild_group=False``: keeps everything alive for reuse. + When ``rebuild_group=True``: full teardown. + """ + if self.rebuild_group: + if self.socket is not None: + try: + self.socket.close() + except Exception as e: + logger.warning(f"Error closing ZMQ socket: {e}") + self.socket = None + + if self.rank is not None and self.rank >= 0 and self.pyhccl is not None: + try: + self.pyhccl.destroyComm(self.pyhccl.comm) + except Exception: + pass + self.pyhccl = None + + self.rank = None + self.world_size = None + self.send_buf = None + self.recv_buf = None + self._prepared = False + self._group_initialized = False + + @classmethod + def build_topology( + cls, + trainer_world_size: int, + rollout_world_size: int, + metadata: list[dict], + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """Build communication topology for HCCL broadcast. + + Same topology as NCCLCheckpointEngine. + """ + master_metadata = None + for m in metadata: + if m is not None: + master_metadata = m + break + + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [master_metadata] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [master_metadata] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): + """Initialize the HCCL process group. + + Idempotent: if already initialized and ``rebuild_group`` is False, + this is a fast no-op. + + Args: + rank: The rank of this worker (-1 for non-participating trainers). + world_size: Total number of workers in the sync group. + master_metadata: Metadata from the master. 
+ """ + # Non-participating trainer ranks + if rank < 0: + self.rank = rank + self.world_size = world_size + self._group_initialized = True + return + + # Fast path: already initialized + if self._group_initialized and not self.rebuild_group: + return + + if self.rebuild_group or self.pyhccl is None: + self.pyhccl = stateless_init_process_group( + master_address=master_metadata.dist_ip, + master_port=master_metadata.dist_port, + rank=rank, + world_size=world_size, + device=self.device, + backend="hccl", + ) + self.rank = rank + self.world_size = world_size + else: + assert self.rank == rank + assert self.world_size == world_size + + # Receivers connect to master's ZMQ PUB server + if self.rank > 0 and self.socket is None: + self._connect_zmq_client(master_metadata) + + # Barrier using all_reduce + signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device()) + self.pyhccl.all_reduce(signal) + + self._group_initialized = True + logger.info(f"init_process_group: rank={self.rank}, world_size={self.world_size}") + + # ── Send / Receive ─────────────────────────────────────────────────── + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send model weights via HCCL broadcast.""" + assert self.rank is not None and self.rank <= 0 + + if self.rank < 0: + for name, weight in weights: + pass + return + + send_buf, recv_buf = self.send_buf, self.recv_buf + broadcast_op = None + + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + + for name, weight in weights: + if offset + weight.nbytes > self.bucket_size: + torch.npu.synchronize() + + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": False}, + socket=self.socket, + topic=self.topic, + ) + + send_buf, recv_buf = recv_buf, send_buf + bucket_meta = {} + offset = 0 + + assert offset + weight.nbytes <= self.bucket_size + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + send_buf[offset:offset + weight.nbytes] = weight.view(-1).view(torch.uint8) + offset += weight.nbytes + + torch.npu.synchronize() + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": True}, + socket=self.socket, + topic=self.topic, + ) + await broadcast_op.wait_for_complete() + + elapsed = time.time() - start_time + logger.info(f"send_weights done: rank={self.rank}, time={elapsed:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive model weights via HCCL broadcast.""" + assert self.rank is not None and self.rank > 0 + + send_buf, recv_buf = self.send_buf, self.recv_buf + total_bytes, total_params = 0, 0 + + start_time = time.time() + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + send_buf, recv_buf = recv_buf, send_buf + + while not metadata["is_last"]: + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + 
bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"]:meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + torch.npu.synchronize() + send_buf, recv_buf = recv_buf, send_buf + + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"]:meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + elapsed = time.time() - start_time + bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024) + logger.info( + f"receive_weights done: rank={self.rank}, params={total_params}, " + f"time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s" + ) diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py new file mode 100644 index 00000000..3e8e9d4e --- /dev/null +++ b/src/twinkle/checkpoint_engine/manager.py @@ -0,0 +1,218 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py +"""Weight synchronization manager for Twinkle (STANDALONE mode). + +Coordinates weight synchronization between training model and inference sampler +when they reside on **different GPUs** (disaggregated / standalone deployment). + +Architecture (following verl's CheckpointEngineManager): + + Trainer GPU(s) Rollout GPU(s) + ┌──────────────────┐ ┌──────────────────┐ + │ TransformersModel│ │ VLLMSampler │ + │ (Ray actors) │ │ (Ray actors) │ + │ │ │ │ │ │ + │ ▼ │ │ ▼ │ + │ CheckpointEngine │ NCCL broadcast │ CheckpointEngine │ + │ send_weights() │ ─────────────────► │ receive_weights()│ + │ │ │ │ │ + │ │ │ ▼ │ + │ │ │ VLLMEngine │ + │ │ │ update_weights()│ + │ │ │ (CUDA IPC) │ + │ │ │ │ │ + │ │ │ ▼ │ + │ │ │ vLLM subprocess │ + │ │ │ load_weights() │ + └──────────────────┘ └──────────────────┘ + +Usage: + >>> manager = CheckpointEngineManager(model=model, sampler=sampler) + >>> manager.sync_weights() # Call after each training step +""" + +import logging +import time +from typing import TYPE_CHECKING + +from .base import CheckpointEngineRegistry + +if TYPE_CHECKING: + from twinkle.model.base import TwinkleModel + from twinkle.sampler.vllm_sampler import VLLMSampler + +logger = logging.getLogger(__name__) + + +class CheckpointEngineManager: + """Coordinate weight synchronization from model to sampler via NCCL broadcast. + + This manager orchestrates a 5-step weight sync flow between Ray actors: + 1. prepare — allocate NCCL buffers, start ZMQ metadata server + 2. build_topology — assign NCCL ranks (trainer[0]→rank0, sampler→rank1..N) + 3. init_process_group — create temporary NCCL group across actors + 4. send / receive — trainer broadcasts, sampler receives (in parallel) + 5. finalize — release buffers, optionally destroy NCCL group + + LoRA-aware sync (following verl's design): + - First sync (``base_sync_done=False``): broadcasts ALL weights (base model) + so that vLLM (loaded with ``load_format='dummy'``) gets real weights. + On the sampler side, ``.base_layer`` is stripped from PEFT weight names + if ``enable_lora=False`` in vLLM, so names match vLLM's model structure. 
+ - Subsequent syncs (``base_sync_done=True``): broadcasts ONLY LoRA adapter + weights. On the sampler side (if ``enable_lora=True``), these are loaded + via ``add_lora()`` as a tensor-based LoRA adapter. + + Args: + model: Training model with Ray actors (``model._actors``). + sampler: Inference sampler with Ray actors (``sampler._actors``). + backend: Checkpoint engine backend (``'nccl'`` or ``'hccl'``). + bucket_size_mb: Size of each weight-transfer bucket in MB. + """ + + def __init__( + self, + model: "TwinkleModel", + sampler: "VLLMSampler", + backend: str = 'nccl', + bucket_size_mb: int = 2048, + ) -> None: + self.model = model + self.sampler = sampler + self.backend = backend + self.bucket_size_mb = bucket_size_mb + self.backend_cls = CheckpointEngineRegistry.get(backend) # nccl, hccl + + # Validate Ray actors + assert hasattr(model, '_actors') and model._actors, \ + "CheckpointEngineManager requires model to be deployed as Ray actors" + assert hasattr(sampler, '_actors') and sampler._actors, \ + "CheckpointEngineManager requires sampler to be deployed as Ray actors" + + self.model_actors = model._actors + self.sampler_actors = sampler._actors + + # LoRA sync state: tracks whether the first full sync has been done. + # After the first sync, only LoRA adapter weights are transferred. + self.base_sync_done: bool = False + # Cached peft_config dict for LoRA-only sync. + # Fetched lazily from the model on first LoRA sync. + self._peft_config: dict | None = None + + logger.info( + f"CheckpointEngineManager: backend={backend}, " + f"model_workers={len(self.model_actors)}, " + f"sampler_workers={len(self.sampler_actors)}" + ) + + def sync_weights(self, adapter_name: str = ''): + """Synchronize weights from model to sampler via NCCL broadcast. + + This is a **blocking** call. It performs: + 1. prepare → allocate buffers on all workers + 2. topology → assign NCCL ranks + 3. init_pg → create temporary NCCL process group + 4. transfer → model broadcasts, sampler receives (parallel) + 5. finalize → release buffers and process group + + LoRA-aware behaviour: + - If the model has a LoRA adapter (``adapter_name`` is not empty) and + this is NOT the first sync, only LoRA weights are sent. + - The sampler side knows via ``peft_config`` / ``base_sync_done`` + whether to use ``load_weights()`` or ``add_lora()`` to apply them. + + Args: + adapter_name: Adapter name for LoRA weight sync. When non-empty + and ``base_sync_done`` is True, only LoRA weights are sent. + """ + import ray + + start_time = time.time() + is_lora_only = self.base_sync_done and bool(adapter_name) + model_actors = self.model_actors + sampler_actors = self.sampler_actors + + # ── Step 1: Prepare ────────────────────────────────────────────── + # All workers allocate buffers. Model actor[0] is designated as + # the master (is_master=True) and starts a ZMQ PUB server. + logger.debug("Step 1/5: prepare checkpoint engines") + model_prep = [ + a.prepare_checkpoint_engine.remote(is_master=(i == 0)) + for i, a in enumerate(model_actors) + ] + sampler_prep = [ + a.prepare_checkpoint_engine.remote(is_master=False) + for a in sampler_actors + ] + all_metadata = ray.get(model_prep + sampler_prep) + model_metadata = all_metadata[:len(model_actors)] + + # ── Step 2: Build topology ─────────────────────────────────────── + # trainer[0] → NCCL rank 0 (source) + # trainer[1..] 
→ rank -1 (not participating) + # sampler[0..N-1] → NCCL rank 1..N (receivers) + logger.debug("Step 2/5: build topology") + model_kwargs, sampler_kwargs = self.backend_cls.build_topology( + len(model_actors), len(sampler_actors), model_metadata, + ) + + # ── Step 3: Init process group ─────────────────────────────────── + logger.debug("Step 3/5: init process group") + init_refs = [] + for i, actor in enumerate(model_actors): + init_refs.append(actor.init_checkpoint_process_group.remote( + rank=model_kwargs["rank"][i], + world_size=model_kwargs["world_size"][i], + master_metadata=model_kwargs["master_metadata"][i], + )) + for i, actor in enumerate(sampler_actors): + init_refs.append(actor.init_checkpoint_process_group.remote( + rank=sampler_kwargs["rank"][i], + world_size=sampler_kwargs["world_size"][i], + master_metadata=sampler_kwargs["master_metadata"][i], + )) + ray.get(init_refs) + + # ── Step 3.5: Fetch peft_config if needed for LoRA-only sync ────── + # On the first LoRA-only sync, fetch the peft_config from the model + # and cache it. This is needed by the sampler's add_lora() path. + peft_config = None + if self.base_sync_done and adapter_name: + if self._peft_config is None: + self._peft_config = ray.get( + model_actors[0].get_peft_config_dict.remote(adapter_name) + ) + peft_config = self._peft_config + + # ── Step 4: Send / Receive (parallel) ──────────────────────────── + logger.debug("Step 4/5: send & receive weights") + send_refs = [ + a.send_weights_via_checkpoint_engine.remote( + adapter_name=adapter_name, + base_sync_done=self.base_sync_done, + ) + for a in model_actors + ] + recv_refs = [ + a.receive_weights_via_checkpoint_engine.remote( + base_sync_done=self.base_sync_done, + peft_config=peft_config, + ) + for a in sampler_actors + ] + ray.get(send_refs + recv_refs) + + # ── Step 5: Finalize ───────────────────────────────────────────── + logger.debug("Step 5/5: finalize") + fin_refs = [a.finalize_checkpoint_engine.remote() for a in model_actors] + fin_refs += [a.finalize_checkpoint_engine.remote() for a in sampler_actors] + ray.get(fin_refs) + + # Mark base sync as done after first successful full sync + if not self.base_sync_done: + self.base_sync_done = True + logger.info("Base model sync completed, subsequent syncs will be LoRA-only") + + elapsed = time.time() - start_time + mode = "LoRA-only" if is_lora_only else "full" + logger.info(f"Weight sync ({mode}) completed in {elapsed:.2f}s") diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py new file mode 100644 index 00000000..c4a45899 --- /dev/null +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -0,0 +1,78 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""CheckpointEngineMixin — shared checkpoint engine lifecycle for Model/Sampler. + +Provides lazy-initialized checkpoint engine with prepare / init_process_group / +finalize methods. Mixed into ``TransformersModel``, ``MegatronModel``, and +``VLLMSampler`` so that the common boilerplate is written only once. + +Only activated when ``CheckpointEngineManager`` calls these methods via +``actor.method.remote()``. When weight sync is not used, the engine is +never created and has zero overhead. +""" + +import logging + +import torch + +from twinkle import remote_function + +from twinkle.checkpoint_engine.base import CheckpointEngine + +logger = logging.getLogger(__name__) + + +class CheckpointEngineMixin: + """Mixin that adds checkpoint engine lifecycle to Model/Sampler classes. 
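+
+    A sketch of the call order that ``CheckpointEngineManager`` drives on
+    each actor (see ``manager.py``):
+
+        >>> meta = actor.prepare_checkpoint_engine.remote(is_master=True)
+        >>> actor.init_checkpoint_process_group.remote(
+        ...     rank=0, world_size=3, master_metadata=ray.get(meta))
+        >>> # send_weights / receive_weights happen here
+        >>> actor.finalize_checkpoint_engine.remote()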
+ + Subclasses only need to implement the transport-specific method: + - ``send_weights_via_checkpoint_engine`` (model side) + - ``receive_weights_via_checkpoint_engine`` (sampler side) + """ + + _checkpoint_engine: "CheckpointEngine | None" = None + _checkpoint_engine_backend: str = 'nccl' + _checkpoint_engine_bucket_size: int = 2048 << 20 # 2 GB + + def _get_or_create_checkpoint_engine(self) -> "CheckpointEngine": + """Get or create the checkpoint engine instance (lazy singleton).""" + if self._checkpoint_engine is None: + if hasattr(torch, 'npu') and torch.npu.is_available(): + backend = 'hccl' + else: + backend = self._checkpoint_engine_backend + from twinkle.checkpoint_engine import CheckpointEngineRegistry + self._checkpoint_engine = CheckpointEngineRegistry.new( + backend, + bucket_size=self._checkpoint_engine_bucket_size, + ) + return self._checkpoint_engine + + @remote_function(dispatch='all') + def prepare_checkpoint_engine(self, is_master: bool = False): + """Prepare checkpoint engine and return metadata for process group setup. + + The ``CheckpointEngineManager`` calls this with ``is_master=True`` for + model actor[0] and ``is_master=False`` for all others. + + Args: + is_master: Whether this worker is the broadcast source. + """ + engine = self._get_or_create_checkpoint_engine() + engine.is_master = is_master + return engine.prepare() + + @remote_function(dispatch='all') + def init_checkpoint_process_group(self, rank: int, world_size: int, master_metadata): + """Initialize process group for weight synchronization.""" + engine = self._get_or_create_checkpoint_engine() + engine.init_process_group( + rank=rank, + world_size=world_size, + master_metadata=master_metadata, + ) + + @remote_function(dispatch='all') + def finalize_checkpoint_engine(self): + """Finalize checkpoint engine: release buffers, optionally destroy group.""" + if self._checkpoint_engine is not None: + self._checkpoint_engine.finalize() diff --git a/src/twinkle/checkpoint_engine/nccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/nccl_checkpoint_engine.py new file mode 100644 index 00000000..d6ea4c5d --- /dev/null +++ b/src/twinkle/checkpoint_engine/nccl_checkpoint_engine.py @@ -0,0 +1,568 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/nccl_checkpoint_engine.py +"""NCCL-based checkpoint engine for disaggregated trainer and rollout. + +This engine uses NCCL broadcast for efficient GPU-to-GPU weight transfer +across different processes/nodes. It supports: +- Double buffering for pipelined transfer +- ZMQ for metadata, NCCL for weight data +- Streaming weight transfer to avoid OOM +- Persistent resources: NCCL group, ZMQ sockets, and buffers are reused + across multiple sync calls to avoid costly re-initialization (~4s per call). + +This implementation uses torch.distributed ProcessGroupNCCL directly for +NCCL operations. A dedicated TCPStore handles rendezvous between the +participating Ray actors, completely independent of any existing default +process group. This avoids NCCL version conflicts between CuPy (compiled +against NCCL 2.25) and the runtime NCCL 2.27 loaded by PyTorch. 
+""" + +import asyncio +import logging +import time +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Generator, Union + +import ray +import torch +import torch.distributed as dist +import zmq + +from twinkle.utils.network import ( + find_free_port, + is_valid_ipv6_address, +) +from .base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta + +logger = logging.getLogger(__name__) + + +@dataclass +class MasterMetadata: + zmq_ip: str + zmq_port: int + # TCPStore address for the checkpoint NCCL process group + nccl_store_host: str = "" + nccl_store_port: int = 0 + + +def _pg_broadcast(pg: dist.ProcessGroup, tensor: torch.Tensor, src: int = 0): + """Broadcast *tensor* using a raw (unregistered) ProcessGroupNCCL. + + ``dist.broadcast()`` requires a *registered* process group. Since we + create the PG directly via ``ProcessGroupNCCL(store, rank, world_size)`` + (which is NOT registered with the default ``_World``), we fall back to + the low-level C++ ``pg.broadcast([tensor], opts)`` API. + """ + opts = dist.BroadcastOptions() + opts.rootRank = src + work = pg.broadcast([tensor], opts) + work.wait() + + +class BroadcastOperation: + """Async broadcast operation with NCCL in separate thread. + + Wraps ``ProcessGroupNCCL.broadcast`` to run asynchronously so the main + thread can continue processing (e.g. filling the next bucket) while the + current bucket is being broadcast. + + Args: + rank: The rank of the current process. + pg: The torch.distributed ProcessGroup (unregistered NCCL). + bucket: The GPU tensor buffer to broadcast. + metadata: The metadata of tensors in the bucket. + socket: The ZMQ socket for metadata communication. + topic: The ZMQ topic for pub/sub. + """ + + def __init__( + self, + rank: int, + pg: dist.ProcessGroup, + bucket: torch.Tensor, + metadata: dict[str, TensorMeta], + socket: zmq.Socket, + topic: str, + ) -> None: + self.rank = rank + self.pg = pg + self.bucket = bucket + self.metadata = metadata + self.socket = socket + self.topic = topic + + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + + def _run(self): + # Broadcast tensor metadata via ZMQ PUB/SUB + if self.rank == 0: + self.socket.send_string(self.topic, flags=zmq.SNDMORE) + self.socket.send_pyobj(self.metadata) + else: + self.socket.recv_string() + self.metadata = self.socket.recv_pyobj() + + # Broadcast tensor data via NCCL + _pg_broadcast(self.pg, self.bucket, src=0) + + async def wait_for_complete(self) -> dict[str, TensorMeta]: + """Wait for the broadcast operation to complete. + + Returns: + The bucket metadata after broadcast. + """ + await self._task + return self.metadata + + +@CheckpointEngineRegistry.register("nccl") +class NCCLCheckpointEngine(CheckpointEngine): + """NCCL checkpoint engine using torch.distributed ProcessGroupNCCL. + + All heavy resources (NCCL process group, ZMQ sockets, GPU buffers) are + **persistent** — they are created once during the first ``prepare()`` / + ``init_process_group()`` call and reused across subsequent syncs. + ``finalize()`` only releases buffers by default; set ``rebuild_group=True`` + if you need to tear everything down each sync. + + Args: + bucket_size: Bucket size in bytes for weight transfer. + Note: Memory overhead is 2 * bucket_size due to double buffering. + group_name: Name of the NCCL process group. + rebuild_group: Whether to destroy the NCCL group after each sync. + rollout_dtype: Target dtype for weights. 
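+
+    Example (a sketch; in practice the engine is constructed through
+    ``CheckpointEngineRegistry`` and driven by ``CheckpointEngineManager``):
+
+        >>> engine = NCCLCheckpointEngine(bucket_size=2048 << 20)
+        >>> engine.is_master = True       # the manager marks model actor[0]
+        >>> metadata = engine.prepare()   # MasterMetadata on master, None otherwise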
+ """ + + def __init__( + self, + bucket_size: int, + group_name: str = "twinkle_ckpt", + rebuild_group: bool = False, + rollout_dtype: torch.dtype = torch.bfloat16, + **kwargs, + ) -> None: + self.bucket_size = bucket_size + self.group_name = group_name + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + + # Set by Manager before prepare() via attribute assignment + self.is_master = False + self.topic = "bucket_metadata" + + # Will be set during prepare / init_process_group + self.rank = None + self.world_size = None + self.send_buf = None + self.recv_buf = None + self.socket = None + + # torch.distributed process group for checkpoint NCCL ops + self._pg: dist.ProcessGroup | None = None + self._store: dist.Store | None = None + + # Track whether resources are ready for reuse + self._prepared = False + self._group_initialized = False + + # ── ZMQ helpers ────────────────────────────────────────────────────── + + def _start_zmq_server(self): + """Start ZMQ PUB server for metadata broadcast (master only).""" + self.ip = ray.util.get_node_ip_address().strip("[]") + self.listen_port = find_free_port() + + context = zmq.Context() + self.socket = context.socket(zmq.PUB) + if is_valid_ipv6_address(self.ip): + address = f"tcp://[{self.ip}]:{self.listen_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{self.ip}:{self.listen_port}" + + self.socket.bind(address) + + def _connect_zmq_client(self, metadata: MasterMetadata): + """Connect to the ZMQ PUB server as a subscriber (receiver only).""" + context = zmq.Context() + self.socket = context.socket(zmq.SUB) + if is_valid_ipv6_address(metadata.zmq_ip): + address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" + + self.socket.connect(address) + self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) + + # ── Core lifecycle ─────────────────────────────────────────────────── + + def prepare(self) -> MasterMetadata | None: + """Allocate double buffers and start ZMQ server (master only). + + Idempotent: if buffers and ZMQ are already set up, returns cached + metadata without re-allocating. + + Returns: + MasterMetadata with ZMQ IP/port and TCPStore address if master, + else None. + """ + if self._prepared: + # Already prepared — return cached metadata + if self.is_master: + return MasterMetadata( + zmq_ip=self.ip, + zmq_port=self.listen_port, + nccl_store_host=self._nccl_store_host, + nccl_store_port=self._nccl_store_port, + ) + return None + + if self.is_master: + # Buffers on CUDA for NCCL broadcast + self.send_buf = torch.zeros( + self.bucket_size, dtype=torch.uint8, device="cuda") + self.recv_buf = torch.zeros( + self.bucket_size, dtype=torch.uint8, device="cuda") + self._start_zmq_server() + + # Allocate a TCPStore port for the checkpoint process group + self._nccl_store_host = self.ip + self._nccl_store_port = find_free_port() + + self._prepared = True + return MasterMetadata( + zmq_ip=self.ip, + zmq_port=self.listen_port, + nccl_store_host=self._nccl_store_host, + nccl_store_port=self._nccl_store_port, + ) + else: + self.send_buf = torch.zeros( + self.bucket_size, dtype=torch.uint8, device="cuda") + self.recv_buf = torch.zeros( + self.bucket_size, dtype=torch.uint8, device="cuda") + self._prepared = True + return None + + def finalize(self): + """Clean up resources after a sync. + + When ``rebuild_group=False`` (default): keeps NCCL group, ZMQ sockets, + and buffers alive for the next sync. 
+ + When ``rebuild_group=True``: destroys NCCL group and ZMQ sockets, + forces a full re-init on the next sync. + """ + if self.rebuild_group: + # Full teardown + if self.socket is not None: + try: + self.socket.close() + except Exception as e: + logger.warning(f"Error closing ZMQ socket: {e}") + self.socket = None + + if self._pg is not None: + # Release PG by dropping references; do NOT call + # dist.destroy_process_group as the PG is unregistered. + self._pg = None + self._store = None + + self.rank = None + self.world_size = None + self.send_buf = None + self.recv_buf = None + self._prepared = False + self._group_initialized = False + + # When rebuild_group=False: keep everything alive for next sync + + @classmethod + def build_topology( + cls, + trainer_world_size: int, + rollout_world_size: int, + metadata: list[dict], + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """Build communication topology for NCCL broadcast. + + The topology assigns: + - Trainer rank 0 -> broadcast source (NCCL rank 0) + - Other trainer ranks -> rank -1 (not participating) + - Rollout workers -> ranks 1, 2, 3, ... (receivers) + + Args: + trainer_world_size: Number of trainer workers. + rollout_world_size: Number of rollout workers. + metadata: List of metadata from prepare() calls. + metadata[0] is the MasterMetadata from trainer rank 0. + + Returns: + Tuple of (trainer_kwargs, rollout_kwargs) for init_process_group(). + """ + master_metadata = metadata[0] + + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [master_metadata] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [master_metadata] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + def init_process_group(self, rank: int, world_size: int, + master_metadata: MasterMetadata): + """Initialize a dedicated NCCL process group for weight synchronization. + + Creates a ``ProcessGroupNCCL`` directly (without registering it in the + default ``_World``), using a ``TCPStore`` hosted by the master for + rendezvous. This is completely independent of any existing + ``torch.distributed`` default process group. + + Idempotent: if the group is already initialized and ``rebuild_group`` + is False, this is a fast no-op. + + Args: + rank: The rank of this worker (-1 for non-participating trainers). + world_size: Total number of workers in the sync group. + master_metadata: Metadata from the master for ZMQ and store + connection. + """ + # Non-participating trainer ranks: record rank and return + if rank < 0: + self.rank = rank + self.world_size = world_size + self._group_initialized = True + return + + # Fast path: group already initialized, skip all setup + if self._group_initialized and not self.rebuild_group: + return + + if self._pg is None: + self.rank = rank + self.world_size = world_size + + # Create a dedicated TCPStore for this checkpoint group. + # Rank 0 (master / trainer) is the store server; all others + # are clients that connect to it. 
+ is_store_master = (rank == 0) + self._store = dist.TCPStore( + host_name=master_metadata.nccl_store_host, + port=master_metadata.nccl_store_port, + world_size=world_size, + is_master=is_store_master, + wait_for_workers=True, + ) + + # Create a ProcessGroupNCCL directly — this does NOT interfere + # with the default process group or any existing torch.distributed + # state. + self._pg = dist.ProcessGroupNCCL( + self._store, rank, world_size, + ) + else: + assert self.rank == rank, f"rank {rank} != self.rank {self.rank}" + assert self.world_size == world_size, ( + f"world_size {world_size} != self.world_size {self.world_size}" + ) + + # Receivers connect to master's ZMQ PUB server + if self.rank > 0 and self.socket is None: + self._connect_zmq_client(master_metadata) + + # Barrier via broadcast to ensure all workers are ready + barrier_tensor = torch.zeros(1, dtype=torch.int32, device="cuda") + _pg_broadcast(self._pg, barrier_tensor, src=0) + torch.cuda.synchronize() + + self._group_initialized = True + logger.info( + f"init_process_group: rank={self.rank}, " + f"world_size={self.world_size}" + ) + + # ── Send / Receive ─────────────────────────────────────────────────── + + @torch.no_grad() + async def send_weights( + self, + weights: Generator[tuple[str, torch.Tensor], None, None], + ): + """Send model weights to rollout workers via NCCL broadcast. + + Uses double buffering: fill send_buf while the previous bucket + is being broadcast, then swap buffers. + + Args: + weights: A generator yielding (name, tensor) pairs. + """ + assert self.rank is not None and self.rank <= 0, ( + "Trainer workers other than rank 0 should not send weights." + ) + + # Non-participating ranks: consume the generator without sending + if self.rank < 0: + for name, weight in weights: + pass + return + + send_buf, recv_buf = self.send_buf, self.recv_buf + broadcast_op = None + + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + + for name, weight in weights: + # Check if bucket is full + if offset + weight.nbytes > self.bucket_size: + torch.cuda.synchronize() + + # Wait for previous broadcast to finish + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + pg=self._pg, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": False}, + socket=self.socket, + topic=self.topic, + ) + + # Swap buffers + send_buf, recv_buf = recv_buf, send_buf + bucket_meta = {} + offset = 0 + + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large " + f"for bucket ({self.bucket_size / 1e6:.1f} MB). " + f"Increase bucket_size." 
+ ) + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + + # Copy weight to buffer (both buffers are on CUDA) + send_buf[offset:offset + weight.nbytes].copy_( + weight.view(-1).view(torch.uint8), non_blocking=True + ) + offset += weight.nbytes + + # Broadcast final bucket + torch.cuda.synchronize() + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + pg=self._pg, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": True}, + socket=self.socket, + topic=self.topic, + ) + await broadcast_op.wait_for_complete() + + logger.info( + f"Rank {self.rank} send weights done, " + f"time cost: {time.time() - start_time:.2f}s" + ) + + @torch.no_grad() + async def receive_weights( + self, + ) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive model weights from trainer via NCCL broadcast. + + Uses double buffering: receive into recv_buf while processing + send_buf, then swap. + + Yields: + Tuples of (name, tensor) for each weight. The tensor is a + *view* into the receive buffer -- callers that need to keep it + should clone it. + """ + assert self.rank is not None and self.rank > 0, ( + "Rank 0 should not receive weights." + ) + + send_buf, recv_buf = self.send_buf, self.recv_buf + total_bytes, total_params = 0, 0 + + # Receive first bucket + start_time = time.time() + broadcast_op = BroadcastOperation( + rank=self.rank, + pg=self._pg, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # Swap buffers + send_buf, recv_buf = recv_buf, send_buf + + while not metadata["is_last"]: + # 1. Start receiving next bucket + broadcast_op = BroadcastOperation( + rank=self.rank, + pg=self._pg, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + + # 2. Yield tensors from current buffer (send_buf) + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[ + meta["offset"]:meta["offset"] + size + ].view(dtype=dtype).view(shape) + yield name, tensor + + # 3. Wait for next bucket + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # 4. 
Swap buffers
+            torch.cuda.synchronize()
+            send_buf, recv_buf = recv_buf, send_buf
+
+        # Yield tensors from final bucket
+        for name, meta in metadata["bucket_meta"].items():
+            dtype, shape = meta["dtype"], meta["shape"]
+            size = dtype.itemsize * shape.numel()
+            tensor = send_buf[
+                meta["offset"]:meta["offset"] + size
+            ].view(dtype=dtype).view(shape)
+            yield name, tensor
+
+        elapsed = time.time() - start_time
+        bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024)
+        logger.info(
+            f"receive_weights done: rank={self.rank}, "
+            f"params={total_params}, "
+            f"time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s"
+        )
diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py
index 0e753fca..416e016b 100644
--- a/src/twinkle/infra/_ray/resource_manager.py
+++ b/src/twinkle/infra/_ray/resource_manager.py
@@ -137,10 +137,10 @@ def __init__(self,
             ranks = group.ranks
             gpus_per_worker = getattr(group, 'gpus_per_worker', 1)
             local_device_groups = []
-            # ranks is only used to declare "how many processes/devices", should not participate in physical device mapping.
-            # When visible devices are already trimmed by ASCEND_RT_VISIBLE_DEVICES etc.,
-            # logical ranks may be non-contiguous like [8,9], need to normalize them in order.
-            normalized_ranks = list(range(len(ranks)))
+            # Use original ranks for GPU mapping so each DeviceGroup maps to
+            # the correct physical devices. E.g. ranks=[2,3] with
+            # nproc_per_node=4 should map to gpu_rank [2,3], not [0,1].
+            normalized_ranks = list(ranks)

             if gpus_per_worker > 1:
                 if len(normalized_ranks) % gpus_per_worker != 0:
diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py
index 222a001c..0d2270a9 100644
--- a/src/twinkle/loss/grpo.py
+++ b/src/twinkle/loss/grpo.py
@@ -1,5 +1,5 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-from typing import Dict, Optional, List, TYPE_CHECKING
+from typing import Dict, Optional, List, TYPE_CHECKING, Union

 from twinkle.loss.base import Loss
 from twinkle.utils.torch_utils import selective_log_softmax
@@ -135,30 +135,194 @@ def _aggregate_loss(
         # Mean over tokens per sequence, then mean over batch
         return (
             (per_token_loss * loss_mask).sum(-1)
-            / loss_mask.sum(-1).clamp(min=1.0)
+            # Token-count normalization now happens in grad clipping; skip the per-sequence division.
+            # / loss_mask.sum(-1).clamp(min=1.0)
         ).mean()

+    @staticmethod
+    def _pad_and_align_logps(
+        logps: 'Union[torch.Tensor, List[List[float]]]',
+        loss_mask: 'torch.Tensor',
+        device: 'torch.device',
+        dtype: 'torch.dtype',
+    ) -> 'torch.Tensor':
+        """Align auxiliary log-probabilities to the model's padded sequence.
+
+        ``old_logps`` and ``ref_logps`` come from the sampler or a reference
+        model and are **not** processed by the ``InputProcessor`` pipeline.
+        They only cover **response tokens** (variable length per sample),
+        whereas the model's ``logps`` covers the entire padded sequence
+        (prompt + response + collation padding + CP/TP padding).
+
+        This method scatters the compact response-only log-probs into the
+        correct positions within the full padded sequence, using ``loss_mask``
+        to determine where response tokens are located.
+
+        If ``logps`` is already a ``[batch, seq_len]`` tensor whose
+        ``seq_len`` matches the target, it is returned as-is (assumed to be
+        pre-aligned, e.g. from a reference model that ran the same pipeline).
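+
+        For example, ``loss_mask = [[False, False, True, True, False]]`` and
+        ``logps = [[-1.2, -0.5]]`` produce ``[[0.0, 0.0, -1.2, -0.5, 0.0]]``.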
+
+        Args:
+            logps: Per-token log probabilities, either:
+                - ``List[List[float]]``: ragged per-sample response log-probs
+                - ``torch.Tensor [batch, response_len]``: compact response-
+                  only tensor (may have different length than target)
+                - ``torch.Tensor [batch, seq_len]``: pre-aligned tensor
+            loss_mask: ``[batch, seq_len]`` bool tensor indicating response
+                token positions within the padded sequence.
+            device: Target device.
+            dtype: Target dtype.
+
+        Returns:
+            Tensor of shape ``[batch, seq_len]`` with log-probs placed at
+            positions where ``loss_mask`` is True, and 0.0 elsewhere.
+        """
+        import torch
+
+        batch_size, target_seq_len = loss_mask.shape
+
+        # ── 1. Normalize input to a list of per-sample 1-D tensors ───────
+        if isinstance(logps, torch.Tensor):
+            logps = logps.to(device=device, dtype=dtype)
+            if logps.dim() == 1:
+                logps = logps.unsqueeze(0)
+            # Already aligned: same batch & seq_len → fast path
+            if logps.shape == (batch_size, target_seq_len):
+                return logps
+            # Compact tensor: split into per-sample tensors
+            per_sample = [logps[i] for i in range(logps.shape[0])]
+        elif isinstance(logps, (list, tuple)):
+            per_sample = [
+                torch.as_tensor(seq, dtype=dtype, device=device)
+                for seq in logps
+            ]
+        else:
+            raise TypeError(f"Unsupported logps type: {type(logps)}")
+
+        # ── 2. Scatter into loss_mask positions ──────────────────────────
+        result = torch.zeros(batch_size, target_seq_len, dtype=dtype, device=device)
+        for i in range(batch_size):
+            positions = loss_mask[i].nonzero(as_tuple=True)[0]
+            n_response = positions.shape[0]
+            n_logps = per_sample[i].shape[0]
+            n = min(n_response, n_logps)
+            if n > 0:
+                result[i, positions[:n]] = per_sample[i][:n].to(device=device, dtype=dtype)
+
+        return result
+
+    @staticmethod
+    def _unpack_packed_logps(
+        logps: 'torch.Tensor',
+        loss_mask: 'torch.Tensor',
+        position_ids: 'Optional[torch.Tensor]',
+        num_sequences: int,
+    ) -> 'tuple':
+        """Unpack packed (padding_free) tensors into per-sequence batch format.
+
+        In padding_free / packing mode, the processor concatenates all
+        sequences into a single row: ``[1, total_tokens]``. This method
+        splits them back into ``[num_sequences, max_seq_len]`` so that
+        per-sequence operations (advantages broadcast, loss aggregation)
+        work correctly.
+
+        Sequence boundaries are detected from ``position_ids`` (which
+        resets to 0 at each boundary). If ``position_ids`` is unavailable,
+        the method falls back to detecting prompt starts in the packed
+        ``loss_mask``.
+
+        Args:
+            logps: ``[1, total_tokens]`` packed log-probabilities.
+            loss_mask: ``[1, total_tokens]`` packed loss mask.
+            position_ids: ``[1, total_tokens]`` packed position ids, or None.
+            num_sequences: Expected number of sequences in the pack.
+
+        Returns:
+            ``(logps, loss_mask)`` each of shape
+            ``[num_sequences, max_seq_len]``, right-padded with 0.
+        """
+        import torch
+
+        total_len = logps.shape[1]
+        logps_flat = logps.squeeze(0)      # [total_tokens]
+        mask_flat = loss_mask.squeeze(0)   # [total_tokens]
+
+        # ── Find sequence boundaries ─────────────────────────────────────
+        if position_ids is not None:
+            pos_flat = position_ids.squeeze(0)  # [total_tokens]
+            # position_ids resets to 0 at each new sequence
+            boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0]
+        else:
+            # Fallback: use loss_mask transitions. Each sequence is a prompt
+            # region (mask=0) followed by a response region (mask=1), so a
+            # new sequence begins at the first prompt token (mask 0) that
+            # directly follows a response token (mask 1).
+            shifted = torch.cat([torch.tensor([False], device=mask_flat.device), mask_flat[:-1]])
+            # Start of a new sequence: transition from mask=1 (end of prev response)
+            # to mask=0 (start of next prompt), or position 0 for the first sequence.
+            prompt_starts = ((~mask_flat) & shifted).nonzero(as_tuple=True)[0]
+            boundary_indices = torch.cat([
+                torch.tensor([0], device=mask_flat.device),
+                prompt_starts,
+            ])
+
+        # Deduplicate & sort
+        boundary_indices = boundary_indices.unique(sorted=True)
+
+        # Add end sentinel
+        boundaries = torch.cat([
+            boundary_indices,
+            torch.tensor([total_len], device=boundary_indices.device),
+        ])
+
+        # ── Split and pad ────────────────────────────────────────────────
+        seq_logps = []
+        seq_masks = []
+        n_seqs = min(boundaries.shape[0] - 1, num_sequences)
+        for i in range(n_seqs):
+            start = boundaries[i].item()
+            end = boundaries[i + 1].item()
+            seq_logps.append(logps_flat[start:end])
+            seq_masks.append(mask_flat[start:end])
+
+        max_len = max(s.shape[0] for s in seq_logps)
+        padded_logps = torch.zeros(n_seqs, max_len, dtype=logps.dtype, device=logps.device)
+        padded_masks = torch.zeros(n_seqs, max_len, dtype=loss_mask.dtype, device=loss_mask.device)
+        for i in range(n_seqs):
+            L = seq_logps[i].shape[0]
+            padded_logps[i, :L] = seq_logps[i]
+            padded_masks[i, :L] = seq_masks[i]
+
+        return padded_logps, padded_masks
+
     def __call__(
         self,
         inputs: Dict,
         outputs: Dict,
         *,
-        old_logps: Optional['torch.Tensor'] = None,
+        old_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None,
         ref_logps: Optional['torch.Tensor'] = None,
-        trajectories: Optional[List[Trajectory]] = None,
+        trajectories: Optional[List[Trajectory]] = None,  # TODO: remove this argument
+        advantages: Optional['torch.Tensor'] = None,
         **kwargs,
     ) -> 'torch.Tensor':
         """
         Compute GRPO loss.

         Args:
-            inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]
+            inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len].
+                In packing mode, also expects 'position_ids' [1, total_tokens].
             outputs: Dict containing either:
                 - 'logps'/'log_probs': [batch, seq_len] pre-computed log probs, OR
                 - 'logits': [batch, seq_len, vocab] from which logps will be computed
             old_logps: [batch, seq_len] or List[List[float]] log probs from old/sampling policy.
-                If None, uses current logps (on-policy, ratio=1).
-            ref_logps: Optional [batch, seq_len] reference model log probs for KL penalty
+                Can have ragged per-sample lengths — will be padded and aligned
+                automatically. If None, uses current logps (on-policy, ratio=1).
+            ref_logps: Optional [batch, seq_len] reference model log probs for KL penalty.
+                Same padding/alignment rules as old_logps.
            trajectories: Optional List[Trajectory] containing advantages
            **kwargs: Additional arguments
@@ -168,9 +332,6 @@ def __call__(
         import torch
         labels = inputs.get('labels')
         assert labels is not None, "inputs must contain 'labels'"
-        # todo: check data_collator return labels as tensor
-        if labels is None:
-            raise ValueError("inputs must contain 'labels'")
         if not torch.is_tensor(labels):
             labels = torch.as_tensor(labels)
         if labels.dim() == 1:
@@ -179,43 +340,65 @@ def __call__(
         logits = outputs.get('logits')
         if logits.shape[1] != labels.shape[1]:
             # some mllm return logits with image tokens, exclude here
-            logits = logits[:, :-labels.shape[1]:]
+            logits = logits[:, -labels.shape[1]:]
-        labels = torch.roll(labels, shifts=-1, dims=1)
         loss_mask = (labels != self.ignore_index).bool()
         masked_labels = labels.clone()
         masked_labels[~loss_mask] = 0
         logps = selective_log_softmax(logits, masked_labels)
+        del logits
+
         device = logps.device
+        # ── Detect and handle packing mode ──────────────────────────────
+        # In padding_free / packing mode the processor concatenates all
+        # sequences into a single row [1, total_tokens]. We detect this
+        # by checking: batch_size == 1 but the actual number of sequences
+        # (known from trajectories or advantages) is greater than 1.
+        if trajectories is not None:
+            num_sequences = len(trajectories)
+        elif advantages is not None:
+            num_sequences = len(advantages) if isinstance(advantages, (list, tuple)) else advantages.shape[0]
+        else:
+            num_sequences = logps.shape[0]
+        is_packed = (logps.shape[0] == 1 and num_sequences > 1)
+        if is_packed:
+            position_ids = inputs.get('position_ids')
+            logps, loss_mask = self._unpack_packed_logps(
+                logps, loss_mask, position_ids, num_sequences,
+            )
+
+        # ── Prepare old_logps ────────────────────────────────────────────
+        # old_logps may be ragged (List[List[float]]) containing only
+        # response-token log-probs, whereas logps covers the full padded
+        # sequence. _pad_and_align_logps scatters them into the correct
+        # positions using loss_mask.
         if old_logps is None:
-            # On-policy
             old_logps = logps.detach()
-
-        assert trajectories is not None, "trajectories must be provided"
-        # TODO: just pass advantages?
- advantages = self._extract_advantages_from_trajectories(trajectories, device) + else: + old_logps = self._pad_and_align_logps( + old_logps, loss_mask, device, logps.dtype, + ) + + # ── Prepare ref_logps (same treatment) ────────────────────────── + if ref_logps is not None: + ref_logps = self._pad_and_align_logps( + ref_logps, loss_mask, device, logps.dtype, + ) + assert advantages is not None, \ + "advantages must be provided (pass as kwarg to forward_backward)" if not torch.is_tensor(advantages): advantages = torch.as_tensor(advantages, device=device, dtype=torch.float32) else: advantages = advantages.to(device, dtype=torch.float32) - # Ensure advantages is 2D for broadcasting + # Ensure advantages is 2D for broadcasting [batch, 1] if advantages.dim() == 1: advantages = advantages.unsqueeze(1) - # Align shapes - assert logps.shape[1] == old_logps.shape[1] == loss_mask.shape[1], ( - "logps, old_logps, and loss_mask must have the same sequence length" - f"but got {logps.shape[1]}, {old_logps.shape[1]}, {loss_mask.shape[1]} respectively") - - if ref_logps is not None: - assert ref_logps.shape[1] == logps.shape[1], ( - "ref_logps must have the same sequence length as logps" - f"but got {ref_logps.shape[1]}, {logps.shape[1]} respectively") - + # ── Compute loss ──────────────────────────────────────────────── log_importance_weights = self._compute_log_importance_weights(logps, old_logps, loss_mask) ratio = torch.exp(log_importance_weights) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index b9f64af2..d2426cee 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -27,6 +27,7 @@ from twinkle.hub import HubOperation from twinkle.loss import Loss, VocabParallelCrossEntropyLoss from twinkle.metric import Metric, LossMetric, TrainMetric +from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.model.base import TwinkleModel from twinkle.processor import InputProcessor from twinkle.template import Template @@ -118,7 +119,7 @@ def calculate_metrics(self, is_training): @remote_class(execute='all') -class MegatronModel(TwinkleModel, nn.Module): +class MegatronModel(TwinkleModel, nn.Module, CheckpointEngineMixin): def __init__( self, @@ -344,9 +345,16 @@ def forward_backward(self, else: seq_length = original_seq_length + loss_extra_kwargs = kwargs + def post_loss_function(output_tensor, inputs): outputs = ModelOutput(logits=output_tensor) - losses, counts = loss_instance(inputs, outputs) + result = loss_instance(inputs, outputs, **loss_extra_kwargs) + if isinstance(result, tuple): + losses, counts = result + else: + losses = result + counts = torch.tensor(1, device=losses.device) return self.strategy.gather_loss_for_cp(losses, counts, output_tensor) # Define forward step function for Megatron @@ -904,6 +912,25 @@ def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torc tqdm_desc='Weight sync: ', ) + @remote_function(dispatch='all', collect='first', sync=True) + def export_weights_dict(self, adapter_name: str = '', lora_only: bool = False) -> Dict[str, torch.Tensor]: + """Export model weights as a dict via Ray object store. + + Collects all weights from ``get_hf_state_dict()`` into a dict. + Used as an alternative to NCCL checkpoint engine for weight sync. + + Args: + adapter_name: Adapter name for weight extraction. + lora_only: If True, only export LoRA adapter weights. + + Returns: + Dict mapping parameter names to CPU tensors. 
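+
+        Example (hypothetical driver-side sketch; the method is a sync
+        remote function, so it can be called like a regular method):
+
+            >>> weights = model.export_weights_dict(lora_only=True)
+            >>> all(t.device.type == 'cpu' for t in weights.values())
+            True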
+        """
+        weights = {}
+        for name, tensor in self.get_hf_state_dict(adapter_name=adapter_name if lora_only else ''):
+            weights[name] = tensor.cpu()
+        return weights
+
     def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str, Dict[str, Any]], **kwargs):
         from .tuners.utils import set_linear_is_expert, get_target_modules, patch_deepcopy
         assert adapter_name, 'Use a non-empty adapter_name'
@@ -1009,6 +1036,8 @@ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor
         adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         kwargs['framework'] = 'megatron'
+        # InputProcessor (processor/base.py) reads device_mesh.cp_world_size, so pass the model's mesh through.
+        kwargs['device_mesh'] = kwargs.get('device_mesh', self.device_mesh)
         optimizer_config.processor = construct_class(processor_cls, InputProcessor, twinkle.processor, **kwargs)

 @remote_function(execute='first')
@@ -1092,3 +1121,112 @@ def _bridge(self) -> GPTBridge:
             self._bridge_instance = megatron_model_meta.bridge_cls()

         return self._bridge_instance
+
+    # ── Checkpoint Engine (from CheckpointEngineMixin) ──────────────────
+    # prepare_checkpoint_engine, init_checkpoint_process_group, and
+    # finalize_checkpoint_engine are inherited from CheckpointEngineMixin.
+    #
+    # Key difference from TransformersModel: Megatron uses TP/PP, so
+    # get_hf_state_dict() internally performs TP allgather and handles PP
+    # layer distribution. All model ranks MUST execute the weight generator
+    # concurrently for the collective communications to complete. Only
+    # model_actor[0] (rank=0 in the checkpoint engine) actually broadcasts
+    # via NCCL; others consume the generator silently (rank=-1).
+
+    @remote_function(dispatch='all')
+    def send_weights_via_checkpoint_engine(
+        self,
+        adapter_name: str = '',
+        base_sync_done: bool = False,
+    ):
+        """Send model weights via NCCL broadcast.
+
+        Uses ``get_hf_state_dict()`` to convert Megatron-format weights to
+        HuggingFace format on-the-fly. The bridge's ``export_weights``
+        internally handles TP allgather and PP layer distribution, so all
+        model ranks must execute this method concurrently.
+
+        LoRA-aware sending:
+        - ``base_sync_done=False``: Send all base model weights (HF format).
+          LoRA-specific weights (lora_A, lora_B) are filtered out.
+        - ``base_sync_done=True`` and ``adapter_name`` set: Send only LoRA
+          adapter weights.
+
+        Args:
+            adapter_name: Adapter name for LoRA weight identification.
+            base_sync_done: If True, only send LoRA adapter weights.
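+
+        A sketch of how the manager drives this (all model actors must run
+        it concurrently so the TP/PP collectives can complete):
+
+            >>> refs = [a.send_weights_via_checkpoint_engine.remote()
+            ...         for a in model_actors]
+            >>> ray.get(refs)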
+ """ + import asyncio + import logging + import threading + + logger = logging.getLogger(__name__) + engine = self._get_or_create_checkpoint_engine() + + if base_sync_done and adapter_name: + # ── LoRA-only mode ──────────────────────────────────────────── + # Export only LoRA adapter weights via the bridge in PEFT format + def weight_generator(): + for name, tensor in self.get_hf_state_dict(adapter_name=adapter_name): + if name is not None and tensor is not None: + yield name, tensor + + logger.info("Sending LoRA-only weights (Megatron)") + else: + # ── Full model mode ─────────────────────────────────────────── + # Export all base weights via the bridge (HF format) + def weight_generator(): + for name, tensor in self.get_hf_state_dict(adapter_name=''): + if name is None or tensor is None: + continue + # Skip LoRA-specific weights for base model sync + if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: + continue + yield name, tensor + + async def _send(): + await engine.send_weights(weight_generator()) + + result_container = {'error': None} + + def _run(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(_send()) + finally: + loop.close() + except Exception as e: + result_container['error'] = e + + thread = threading.Thread(target=_run) + thread.start() + thread.join() + + if result_container['error'] is not None: + raise result_container['error'] + + @remote_function(collect='first') + def get_peft_config_dict(self, adapter_name: str = '') -> dict: + """Return the PEFT config as a dict for vLLM's PEFTHelper. + + Used by CheckpointEngineManager for LoRA-only weight sync. + + Returns: + PEFT config dict, or None if no LoRA adapter is present. + """ + optimizer_config = self.optimizer_group.get(adapter_name) + if optimizer_config is None or optimizer_config.adapter_config is None: + return None + config = optimizer_config.adapter_config + if isinstance(config, dict): + config = config.get(adapter_name, next(iter(config.values()))) + return config.to_dict() if hasattr(config, 'to_dict') else dict(config) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 0d2c9aff..561a36fb 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -33,6 +33,7 @@ from twinkle.template import Template from twinkle.utils import torch_util, construct_class from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.model.base import TwinkleModel from twinkle.model.transformers.moe import apply_expert_parallel from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy @@ -139,7 +140,7 @@ def calculate_metrics(self, is_training): DEFAULT_WEIGHT_DECAY = 0.01 @remote_class() -class TransformersModel(TwinkleModel, PreTrainedModel): +class TransformersModel(TwinkleModel, PreTrainedModel, CheckpointEngineMixin): """The transformers model wrapper. Args: @@ -197,6 +198,7 @@ def __init__(self, # noqa model_id = HubOperation.download_model(model_id) self.model = model_cls.from_pretrained(model_id, config=config, **kwargs) # Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects. 
+ self.model.gradient_checkpointing_enable() self.sp_strategy = None self._model_wrapped = False self.optimizer_group: Dict[str, OptimizerGroup] = {_default_adapter_name: self._construct_default_optimizer_group()} @@ -325,6 +327,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec adapter_name = kwargs.pop('adapter_name', self._get_default_group()) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() + self.model.train() if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list) and self._not_encoded(inputs[0])): # Trajectory or List[Trajectory] assert optimizer_config.template is not None, \ @@ -364,6 +367,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T adapter_name = kwargs.pop('adapter_name', self._get_default_group()) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() + self.model.eval() if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list) and self._not_encoded(inputs[0])): # Trajectory or List[Trajectory] assert optimizer_config.template is not None, \ @@ -867,7 +871,29 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): @remote_function(collect='first') def get_state_dict(self, **kwargs): - return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group())) + return self.strategy.unwrap_model(self.model).state_dict() + + @remote_function(collect='first') + def get_peft_config_dict(self, adapter_name: str = None) -> dict: + """Return the PEFT config as a dict for vLLM's PEFTHelper. + + Used by CheckpointEngineManager to pass peft_config to the sampler + when doing LoRA-only weight sync. + + Returns: + PEFT config dict, or None if the model has no LoRA adapter. + """ + if adapter_name is None: + adapter_name = self._get_default_group() + optimizer_config = self.optimizer_group.get(adapter_name) + if optimizer_config is None or optimizer_config.adapter_config is None: + return None + config = optimizer_config.adapter_config + # PeftConfig can be a dict-like mapping (e.g. {adapter_name: LoraConfig}) + # or a single LoraConfig. Normalize to a single config. + if isinstance(config, dict): + config = config.get(adapter_name, next(iter(config.values()))) + return config.to_dict() if hasattr(config, 'to_dict') else dict(config) @remote_function(collect='first') def calculate_metric(self, is_training, **kwargs): @@ -1027,3 +1053,97 @@ def get_train_configs(self, **kwargs) -> str: f'Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}%\n' ) return expr + + # ========================================================================= + # Checkpoint Engine — Weight Sync (from CheckpointEngineMixin) + # ========================================================================= + # prepare_checkpoint_engine, init_checkpoint_process_group, and + # finalize_checkpoint_engine are inherited from CheckpointEngineMixin. + # Only send_weights_via_checkpoint_engine is model-specific. + + @remote_function(dispatch='all') + def send_weights_via_checkpoint_engine( + self, + adapter_name: str = '', + base_sync_done: bool = False, + ): + """Send model weights via NCCL broadcast. + + Called on ALL model workers: + - Rank 0 (master): Collects weights and broadcasts via NCCL. + - Other ranks (rank=-1): Consume the generator without sending. 
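+
+        The manager gathers these send refs together with the samplers'
+        ``receive_weights_via_checkpoint_engine`` refs in a single
+        ``ray.get``, so broadcast and receive overlap.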
+ + LoRA-aware sending (follows verl's design): + - ``base_sync_done=False``: Send ALL base weights from the full state + dict. Names are sent as-is (including ``.base_layer`` for PEFT + models); the **sampler** side handles stripping ``.base_layer`` + based on its vLLM ``enable_lora`` setting. + LoRA-specific weights (lora_A, lora_B) are filtered out. + - ``base_sync_done=True`` and ``adapter_name`` is set: Send only LoRA + adapter weights via ``get_peft_model_state_dict()``. The sampler + converts names to vLLM LoRA format for ``add_lora()``. + + Args: + adapter_name: Adapter name for LoRA weight identification. + base_sync_done: If True, only send LoRA adapter weights. + """ + import asyncio + import threading + from twinkle.utils.framework import Torch + + engine = self._get_or_create_checkpoint_engine() + + # Get state dict from unwrapped model + model = self.strategy.unwrap_model(self.model) + + if base_sync_done and adapter_name: + # ── LoRA-only mode: send only adapter weights ──────────────── + # Use PEFT's get_peft_model_state_dict for clean LoRA extraction + from peft.utils import get_peft_model_state_dict + lora_state_dict = get_peft_model_state_dict(model) + + def weight_generator(): + for name, tensor in lora_state_dict.items(): + tensor = Torch.to_local_tensor(tensor) + yield name, tensor + + else: + # ── Full model mode: send all weights (base model sync) ────── + state_dict = model.state_dict() + + def weight_generator(): + for name, tensor in state_dict.items(): + # Skip LoRA-specific weights for base model sync + if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: + continue + tensor = Torch.to_local_tensor(tensor) + # Keep original names (including .base_layer for PEFT models). + # The sampler side will strip .base_layer based on whether + # vLLM has enable_lora=True/False. + yield name, tensor + + # Run async send_weights in a dedicated event loop thread. + # We cannot use the Ray worker's event loop because it may already + # be occupied, and send_weights uses run_in_executor internally. + async def _send(): + await engine.send_weights(weight_generator()) + + result_container = {'error': None} + + def _run(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(_send()) + finally: + loop.close() + except Exception as e: + result_container['error'] = e + + thread = threading.Thread(target=_run) + thread.start() + thread.join() + + if result_container['error'] is not None: + raise result_container['error'] diff --git a/src/twinkle/patch/vllm_lora_weights.py b/src/twinkle/patch/vllm_lora_weights.py index 733f3edd..69ab2a6c 100644 --- a/src/twinkle/patch/vllm_lora_weights.py +++ b/src/twinkle/patch/vllm_lora_weights.py @@ -36,7 +36,7 @@ def __call__(self, sampler, **kwargs): def _get_tokenizer(): """Get tokenizer lazily from sampler's template.""" - if _sampler_ref.template is not None: + if _sampler_ref and _sampler_ref.template is not None: return _sampler_ref.template.tokenizer return None diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 0c797c7b..3c769c4a 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import Preprocessor, DataFilter
-from .llm import CompetitionMathProcessor, CompetitionMathGRPOProcessor, SelfCognitionProcessor, AlpacaProcessor
+from .llm import CompetitionMathProcessor, CompetitionMathGRPOProcessor, SelfCognitionProcessor, AlpacaProcessor, CountdownProcessor
diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py
index 5a88a751..908fdd13 100644
--- a/src/twinkle/preprocessor/llm.py
+++ b/src/twinkle/preprocessor/llm.py
@@ -57,3 +57,24 @@ def __call__(self, row) -> Trajectory:
             Message(role='assistant', content=output_text),
         ]
         return Trajectory(messages=messages)
+
+class CountdownProcessor(Preprocessor):
+    system_prompt = (
+        "You are a helpful assistant. You first think about the reasoning process "
+        "in your mind and then provide the user with the answer."
+    )
+    def __call__(self, row) -> Trajectory:
+        nums = row.get('nums', [])
+        target = row.get('response', row.get('target', 0))
+
+        query = f"""Using the numbers {nums}, create an equation that equals {target}.
+You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
+Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags,
+for example <answer> (1 + 2) / 3 * 4 = 4 </answer>."""
+
+        messages = [
+            Message(role='system', content=self.system_prompt),
+            Message(role='user', content=query),
+            Message(role='assistant', content=''),
+        ]
+        return Trajectory(messages=messages, user_data=[{'target': target, 'nums': nums}])
\ No newline at end of file
diff --git a/src/twinkle/processor/grpo.py b/src/twinkle/processor/grpo.py
index 98d3ab69..532d1729 100644
--- a/src/twinkle/processor/grpo.py
+++ b/src/twinkle/processor/grpo.py
@@ -13,7 +13,6 @@
 from twinkle import DeviceMesh, remote_class
 from twinkle.processor import InputProcessor
 
-# TODO: remove
 @remote_class()
 class GRPOLossProcessor(InputProcessor):
     """
diff --git a/src/twinkle/sampler/base.py b/src/twinkle/sampler/base.py
index f779c828..325e404e 100644
--- a/src/twinkle/sampler/base.py
+++ b/src/twinkle/sampler/base.py
@@ -89,6 +89,8 @@ def _check_adapter_valid(self, adapter_name: str):
         assert adapter_name in self.sample_group, \
             f'Invalid adapter_name: {adapter_name}. Available: {list(self.sample_group.keys())}'
 
+    # used in grpo demo, TODO: remove remote_function
+    @remote_function(dispatch='all', collect='first', lazy_collect=False)
     def _get_template(self, adapter_name: str = '') -> Optional[Template]:
         if adapter_name and adapter_name in self.sample_group:
             template = self.sample_group[adapter_name].template
diff --git a/src/twinkle/sampler/vllm_engine.py b/src/twinkle/sampler/vllm_engine.py
index 4bc993c4..422efa45 100644
--- a/src/twinkle/sampler/vllm_engine.py
+++ b/src/twinkle/sampler/vllm_engine.py
@@ -191,6 +191,7 @@ async def sample(
         topk_prompt_logprobs: int = 0,
         adapter_path: Optional[str] = None,
         adapter_user_id: Optional[str] = None,
+        lora_request: Optional[Any] = None,
         request_id: Optional[str] = None,
         priority: int = 0,
         *,
@@ -211,6 +212,8 @@ async def sample(
             topk_prompt_logprobs: If > 0, returns top-k logprobs for each prompt token.
             adapter_path: Resolved filesystem path to LoRA adapter directory.
             adapter_user_id: User identifier for the adapter (for tracking loaded adapters).
+            lora_request: Pre-built LoRARequest for RL training weight sync.
+                Takes precedence over adapter_path.
             request_id: Optional request ID for tracking.
             priority: Request priority (higher = more urgent).
             images: Optional list of images for multimodal models.
@@ -253,9 +256,9 @@ async def sample( else: prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - # Build LoRA request if adapter_path provided - lora_request = None - if adapter_path and self.enable_lora: + # Build LoRA request: pre-built lora_request takes precedence, + # then adapter_path for multi-tenant mode. + if lora_request is None and adapter_path and self.enable_lora: lora_request = await self._get_or_load_lora(adapter_path, adapter_user_id) # Generate @@ -463,7 +466,7 @@ async def sleep(self, level: int = 2) -> None: await self.engine.sleep(level=level) logger.debug(f"Engine sleeping at level {level}") - async def wake_up(self, tags: Optional[List[str]] = None, reload_weights: bool = False) -> None: + async def wake_up(self, tags: Optional[List[str]] = None) -> None: """ Resume weights and/or KV cache to GPU memory. @@ -485,15 +488,6 @@ async def wake_up(self, tags: Optional[List[str]] = None, reload_weights: bool = await self.engine.wake_up(tags=tags) - if reload_weights and "weights" in tags: - try: - await self.engine.collective_rpc("reload_weights") - logger.debug("Weights reloaded after wake_up") - except Exception as e: - logger.warning(f"Failed to reload weights: {e}") - - await self.clear_kv_cache() - logger.debug(f"Engine waking up with tags: {tags}") async def clear_kv_cache(self) -> None: @@ -503,15 +497,193 @@ async def clear_kv_cache(self) -> None: elif hasattr(self.engine, 'reset_mm_cache'): await self.engine.reset_mm_cache() # Do we need this? + async def reset_prefix_cache(self) -> None: + await self.engine.reset_prefix_cache() + async def update_weights( self, weights: Dict[str, torch.Tensor], - adapter_name: Optional[str] = None, + peft_config: Optional[dict] = None, + base_sync_done: bool = False, + bucket_size_mb: int = 2048, **kwargs, ) -> None: - # not use, TODO: remove this method - await self.engine.model_runner.model.load_weights(weights) + """Update model weights via ZMQ + CUDA IPC/SHM to worker extension. + + vLLM v1 AsyncLLM runs the model in a separate ``WorkerProc`` subprocess. + We use a ZMQ REQ/REP + CUDA IPC (or SHM for CPU tensors) channel to + stream weights in buckets to the worker, avoiding pickle serialization + overhead for large tensor dicts. + + LoRA-aware: + - ``base_sync_done=False``: Weights loaded via ``load_weights()`` + in the vLLM worker (base model sync). + - ``base_sync_done=True`` with ``peft_config``: Weights loaded as + a LoRA adapter via ``add_lora()`` in the vLLM worker. + + Args: + weights: Dict mapping weight names to tensors (GPU or CPU). + peft_config: PEFT config dict for LoRA adapter loading. + base_sync_done: If True with peft_config, load as LoRA adapter. + bucket_size_mb: Size of transfer bucket in MB. 
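+
+        Example (a sketch; ``engine`` is this VLLMEngine and ``lora_sd`` /
+        ``peft_cfg`` are assumed names for an adapter state dict and config):
+            >>> await engine.update_weights(hf_state_dict)  # base-model sync
+            >>> await engine.update_weights(lora_sd, peft_config=peft_cfg,
+            ...                             base_sync_done=True)  # LoRA-only sync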
+        """
+        import asyncio
+        import gc
+        import time
+        import uuid
+        import zmq
+        from concurrent.futures import ThreadPoolExecutor
+        from vllm.platforms import current_platform
+
+        start_time = time.time()
+
+        # Determine if weights are on GPU (use CUDA IPC) or CPU (use SHM)
+        first_weight = next(iter(weights.values()))
+        use_gpu_ipc = first_weight.is_cuda
+        use_shm = not use_gpu_ipc
+
+        # Get device UUID for ZMQ handle
+        device_uuid = current_platform.get_device_uuid(0)
+        zmq_handle = f"ipc:///tmp/twinkle-ipc-{device_uuid}.sock"
+
+        bucket_size = bucket_size_mb << 20
+
+        # Create transfer buffer
+        buffer = None
+        shm = None
+
+        if use_gpu_ipc:
+            from torch.multiprocessing.reductions import reduce_tensor
+            buffer = torch.empty(bucket_size, dtype=torch.uint8, device=first_weight.device)
+            ipc_handle = reduce_tensor(buffer)
+        else:
+            from multiprocessing import shared_memory
+            shm_name = f"twinkle_weights_{uuid.uuid4().hex}"
+            shm = shared_memory.SharedMemory(name=shm_name, create=True, size=bucket_size)
+            buffer = torch.frombuffer(shm.buf, dtype=torch.uint8)
+
+        # Setup ZMQ socket FIRST (bind before worker connects)
+        zmq_ctx = zmq.Context()
+        socket = zmq_ctx.socket(zmq.REQ)
+        socket.bind(zmq_handle)
+
+        def _send_weights_via_zmq():
+            """Send weights via ZMQ in a separate thread."""
+            if use_gpu_ipc:
+                socket.send_pyobj(ipc_handle)
+            else:
+                socket.send_pyobj({"name": shm_name, "size": bucket_size})
+            socket.recv()  # Wait for worker ready
+
+            offset = 0
+            bucket_meta = {}
+
+            for name, weight in weights.items():
+                if use_shm and weight.is_cuda:
+                    weight = weight.cpu()
+
+                if weight.nbytes > bucket_size:
+                    raise ValueError(
+                        f"Weight {name} ({weight.nbytes / (1 << 20):.1f} MB) exceeds "
+                        f"bucket size ({bucket_size_mb} MB). Increase bucket_size_mb."
+                    )
+
+                if offset + weight.nbytes > bucket_size:
+                    if use_gpu_ipc:
+                        torch.cuda.synchronize()
+                    socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False})
+                    socket.recv()
+                    bucket_meta = {}
+                    offset = 0
+
+                bucket_meta[name] = {
+                    "name": name,
+                    "shape": weight.shape,
+                    "dtype": weight.dtype,
+                    "offset": offset,
+                }
+                buffer[offset:offset + weight.nbytes].copy_(
+                    weight.view(-1).view(torch.uint8), non_blocking=True
+                )
+                offset += weight.nbytes
+
+            if use_gpu_ipc:
+                torch.cuda.synchronize()
+            socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True})
+            socket.recv()
+
+        # Run ZMQ communication and collective_rpc concurrently
+        loop = asyncio.get_running_loop()
+        worker_task = asyncio.create_task(
+            self.engine.collective_rpc(
+                "update_weights_from_ipc",
+                kwargs={
+                    "peft_config": peft_config,
+                    "base_sync_done": base_sync_done,
+                    "use_shm": use_shm,
+                },
+            )
+        )
+
+        await asyncio.sleep(0.1)
+
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            await loop.run_in_executor(executor, _send_weights_via_zmq)
+
+        await worker_task
+
+        # Clean up
+        socket.close()
+        zmq_ctx.term()
+        del buffer
+        if shm is not None:
+            shm.close()
+            shm.unlink()
+            del shm
+        gc.collect()
+
+        elapsed = time.time() - start_time
+        mode = "LoRA" if base_sync_done and peft_config else "base"
+        logger.info(f"Updated {len(weights)} {mode} weights via "
+                    f"{'IPC' if use_gpu_ipc else 'SHM'} in {elapsed:.2f}s")
 
     async def abort_request(self, request_id: str) -> None:
         """Abort a specific request."""
         await self.engine.abort(request_id)
+
+    async def shutdown(self) -> None:
+        """Shutdown the vLLM engine and release all resources.
+
+        This method should be called before the process exits to ensure
+        proper cleanup of the vLLM AsyncLLM engine and its subprocesses.
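+
+        Example (sketch; awaited from the owning sampler's event loop):
+            >>> await engine.shutdown()
+            >>> await engine.shutdown()  # no-op: self.engine was set to None above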
+ """ + import gc + + logger.info("Shutting down VLLMEngine...") + + if self.engine is not None: + try: + # vLLM v1 AsyncLLM has shutdown() method + if hasattr(self.engine, 'shutdown'): + await self.engine.shutdown() + elif hasattr(self.engine, 'engine_core'): + # For older versions, try to stop engine core + if hasattr(self.engine.engine_core, 'shutdown'): + await self.engine.engine_core.shutdown() + except Exception as e: + logger.warning(f"Error during engine shutdown: {e}") + finally: + self.engine = None + + # Clear LoRA state + self._user_lora_ids.clear() + self._user_lora_paths.clear() + + # Force garbage collection + gc.collect() + + # Clear CUDA cache if available + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if hasattr(torch.cuda, 'ipc_collect'): + torch.cuda.ipc_collect() + + logger.info("VLLMEngine shutdown complete") \ No newline at end of file diff --git a/src/twinkle/sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler.py index 288cc794..9c55081e 100644 --- a/src/twinkle/sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler.py @@ -25,11 +25,14 @@ import os import threading from dataclasses import asdict -from typing import List, Dict, Any, Union, Optional +from typing import List, Dict, Any, Union, Optional, Literal + +import torch from .base import Sampler from .types import SamplingParams, SampleResponse, SampledSequence from twinkle import remote_function, remote_class, DeviceMesh, requires +from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.utils.platform import Platform from twinkle.data_format import InputFeature, Trajectory from twinkle.patch.vllm_lora_weights import VLLMLoraWeights, TensorLoRARequest @@ -61,7 +64,7 @@ def _collect_sample_responses(results: List[SampleResponse]) -> SampleResponse: @remote_class() -class VLLMSampler(Sampler): +class VLLMSampler(Sampler, CheckpointEngineMixin): """A vLLM-based sampler using VLLMEngine (AsyncLLM). This sampler automatically configures vLLM based on available GPUs. @@ -137,7 +140,11 @@ def __init__( self._create_engine_async(VLLMEngine, model_id, engine_kwargs) ) - VLLMLoraWeights().patch(self) + VLLMLoraWeights()(self) + + # Track LoRA loaded via checkpoint engine sync. + # When set, sampling automatically uses this LoRA request. + self._ckpt_lora_loaded: bool = False self._shutdown_called = False atexit.register(self.shutdown) @@ -229,7 +236,7 @@ async def _sample_single( self, feat: Dict[str, Any], sampling_params: SamplingParams, - adapter_uri: Optional[str] = None, + adapter_path: Optional[str] = None, request_seed: Optional[int] = None, *, num_samples: int = 1, @@ -239,7 +246,7 @@ async def _sample_single( Args: feat: Encoded input features containing 'input_ids' and optionally 'images'/'videos'. sampling_params: Sampling parameters. - adapter_uri: Optional LoRA adapter URI. + adapter_path: Optional LoRA adapter path. request_seed: Optional seed for reproducibility. num_samples: Number of completions to generate for this prompt. @@ -253,11 +260,26 @@ async def _sample_single( images = feat.get('images') videos = feat.get('videos') + # If a LoRA adapter was loaded via checkpoint engine sync and + # no explicit adapter_path is provided, use the synced LoRA. 
+ lora_request = None + if not adapter_path and self._ckpt_lora_loaded: + from vllm.lora.request import LoRARequest + from twinkle.sampler.vllm_worker_extension import ( + VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, + ) + lora_request = LoRARequest( + lora_name=VLLM_LORA_NAME, + lora_int_id=VLLM_LORA_INT_ID, + lora_path=VLLM_LORA_PATH, + ) + response = await self.engine.sample( prompt_token_ids=input_ids, sampling_params=sampling_params, num_samples=num_samples, - adapter_uri=adapter_uri, + adapter_path=adapter_path, + lora_request=lora_request, images=images, videos=videos, ) @@ -278,9 +300,10 @@ def sample( inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None, adapter_name: str = '', - adapter_uri: Optional[str] = None, + adapter_path: Optional[str] = None, *, num_samples: int = 1, + return_type: Literal['sample_response', 'trajectory', 'InputFeature'] = 'sample_response', # TODO:enum ) -> SampleResponse: """Sample responses for given inputs. @@ -290,6 +313,7 @@ def sample( - Trajectory: Must contain 'messages'. Requires template to be set. sampling_params: Sampling parameters. adapter_name: Optional LoRA adapter name. + adapter_path: Optional LoRA adapter path. num_samples: Number of completions to generate per input prompt. When > 1, returns num_samples sequences for each input. @@ -329,7 +353,7 @@ def sample( # Sample all inputs in parallel using background event loop async def _sample_all(): tasks = [ - self._sample_single(feat, sampling_params, adapter_uri, num_samples=num_samples) + self._sample_single(feat, sampling_params, adapter_path, num_samples=num_samples) for feat in encoded_inputs ] return await asyncio.gather(*tasks) @@ -390,29 +414,124 @@ def remove_adapter(self, adapter_name: str): self.sample_group.pop(adapter_name) @remote_function(dispatch='all', collect='first') - def sleep(self, level: int = 2) -> None: - """Release GPU memory for colocate mode. - - Call this before training to free up GPU memory used by vLLM. - - Args: - level: Sleep level (1=light, 2=deep). Default 2 releases most memory. + def sleep(self, level: int = 1) -> None: + """ + Release GPU memory for colocate mode. """ self._run_in_loop(self.engine.sleep(level)) @remote_function(dispatch='all', collect='first') - def wake_up(self, tags: List[str] = None, reload_weights: bool = False) -> None: - """Resume GPU memory for colocate mode. - - Call this before sampling to reload weights/KV cache into GPU. - + def wake_up(self, tags: List[str] = None) -> None: + self._run_in_loop(self.engine.wake_up(tags=tags)) + + @remote_function(dispatch='all', collect='first') + def reset_prefix_cache(self): + self._run_in_loop(self.engine.reset_prefix_cache()) + + @remote_function(dispatch='all') + def import_weights_dict( + self, + weights: Dict[str, Any], + peft_config: dict = None, + base_sync_done: bool = False, + ): + """Import weights from a dict via Ray object store. + + Alternative to NCCL checkpoint engine for weight sync. Avoids + ray.util.collective NCCL which can conflict with Megatron's NCCL. + Args: - tags: Optional list of memory types to resume (e.g., ['weights', 'kv_cache']). - If None, resumes all. - reload_weights: If True, reload weights from disk after wake_up. - Required after level 2 sleep which discards weights. + weights: Dict mapping parameter names to tensors. + peft_config: PEFT config dict for LoRA adapter loading. + base_sync_done: If True, load as LoRA adapter. 
+ + Returns: + Number of weights loaded. """ - self._run_in_loop(self.engine.wake_up(tags=tags, reload_weights=reload_weights)) + import gc + from twinkle.utils.framework import Torch + + # Move weights to GPU + device = Torch.get_device(None) + gpu_weights = {} + for name, tensor in weights.items(): + gpu_weights[name] = tensor.to(device, non_blocking=True) + Torch.synchronize() + + # Transfer to vLLM subprocess via engine.update_weights + async def _load(): + await self.engine.update_weights( + gpu_weights, + peft_config=peft_config, + base_sync_done=base_sync_done, + ) + + self._run_in_loop(_load()) + + n_loaded = len(gpu_weights) + del gpu_weights + gc.collect() + Torch.empty_cache() + + logger.info(f"Imported {n_loaded} weights from dict" + f" ({'LoRA' if base_sync_done and peft_config else 'base'} mode)") + return n_loaded + + # ========================================================================= + # Checkpoint Engine — Weight Sync (from CheckpointEngineMixin) + # ========================================================================= + # prepare_checkpoint_engine, init_checkpoint_process_group, and + # finalize_checkpoint_engine are inherited from CheckpointEngineMixin. + # Only receive_weights_via_checkpoint_engine is sampler-specific. + + @remote_function(dispatch='all') + def receive_weights_via_checkpoint_engine( + self, + base_sync_done: bool = False, + peft_config: dict = None, + ): + """Receive weights via NCCL broadcast and load into vLLM. + + Flow: + 1. Receive weights from NCCL broadcast (double-buffered GPU tensors) + 2. Clone into a dict (buffer will be reused for next bucket) + 3. Pass to ``VLLMEngine.update_weights()`` → ``collective_rpc`` → + ``TwinkleWorkerExtension.load_synced_weights()`` in the vLLM + worker subprocess, which handles name conversion and loading. + + Args: + base_sync_done: If True, this is a LoRA-only sync. + peft_config: PEFT config dict for LoRA adapter loading. + + Returns: + Number of weights loaded. + """ + engine = self._get_or_create_checkpoint_engine() + + async def _receive_and_load(): + # Collect weights with original names — name conversion is done + # in the vLLM worker subprocess (TwinkleWorkerExtension). + weights = {} + async for name, tensor in engine.receive_weights(): + weights[name] = tensor.clone() + + if not weights: + return 0 + + await self.engine.update_weights( + weights, + peft_config=peft_config, + base_sync_done=base_sync_done, + ) + + # After LoRA sync, mark that the synced LoRA is loaded so + # sampling automatically uses it. + if base_sync_done and peft_config: + self._ckpt_lora_loaded = True + + return len(weights) + + return self._run_in_loop(_receive_and_load()) def shutdown(self): """Gracefully shutdown the vLLM engine and background event loop. diff --git a/src/twinkle/sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_worker_extension.py index d65796ec..91bd6d17 100644 --- a/src/twinkle/sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_worker_extension.py @@ -1,86 +1,96 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""vLLM Worker Extension for colocated training. +"""vLLM Worker Extension for weight synchronization. -This module provides a Worker extension class that enables direct weight -synchronization from training to vLLM inference workers. +This module provides a Worker extension class that enables weight +synchronization from training to vLLM inference workers via collective_rpc. 
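+
+A typical engine-side trigger (a sketch; the RPC name is the method this
+extension defines below):
+
+    await engine.collective_rpc(
+        "update_weights_from_ipc",
+        kwargs={"peft_config": None, "base_sync_done": False, "use_shm": False},
+    )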
The extension class is injected into vLLM workers via the `worker_extension_cls` parameter and provides methods for: - Direct weight loading via model.load_weights() -- CUDA IPC weight transfer (colocate mode) -- LoRA adapter loading -- Weight synchronization coordination +- LoRA adapter loading via add_lora() Reference: verl's vLLMColocateWorkerExtension implementation. """ import gc -import hashlib import logging import os -from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple +import platform +import ctypes +import re +import signal +from typing import Dict, List, Optional, Tuple import torch -from twinkle.utils.framework import Framework, Torch +from twinkle.utils.framework import Torch logger = logging.getLogger(__name__) -# TODO: get from tenant context + +def set_death_signal(): + """Kill the current process when the parent process exits.""" + if platform.system() != "Linux": + return + libc = ctypes.CDLL("libc.so.6") + libc.prctl(1, signal.SIGKILL) + if os.getppid() == 1: + os.kill(os.getpid(), signal.SIGKILL) + + +# Constants for the RL training LoRA adapter identity. VLLM_LORA_INT_ID = 1 VLLM_LORA_NAME = "twinkle_lora" VLLM_LORA_PATH = "twinkle_lora_path" -def rebuild_ipc(handle: Tuple[Callable, tuple], device_id: Optional[int] = None) -> torch.Tensor: + +def _rebuild_ipc(handle, device_id: Optional[int] = None) -> torch.Tensor: + """Rebuild CUDA tensor from IPC handle.""" + from torch.multiprocessing.reductions import rebuild_cuda_tensor + func, args = handle list_args = list(args) if device_id is not None: - # Change device ID for different CUDA_VISIBLE_DEVICES list_args[6] = device_id - return func(*list_args) + + if callable(func): + return func(*list_args) + else: + return rebuild_cuda_tensor(*list_args) -def rebuild_shared_memory(name: str, size: int, dtype=torch.uint8): - """Rebuild tensor from shared memory.""" +def _rebuild_shared_memory(name: str, size: int): + """Rebuild tensor from shared memory. Returns (tensor, shm).""" + from multiprocessing import shared_memory shm = shared_memory.SharedMemory(name=name) - tensor = torch.frombuffer(shm.buf[:size], dtype=dtype) + tensor = torch.frombuffer(shm.buf[:size], dtype=torch.uint8) return tensor, shm -def get_device_uuid(device_id: int) -> str: +def _get_device_uuid(device_id: int) -> str: """Get unique device identifier.""" from vllm.platforms import current_platform return current_platform.get_device_uuid(device_id) class TwinkleWorkerExtension: - """ - Extension class for vLLM workers to support weight synchronization. - - This class is designed to be mixed into vLLM's Worker class via the - `worker_extension_cls` parameter. It provides direct access to the - model's load_weights method for efficient weight synchronization. + """Extension class for vLLM workers to support weight synchronization. + + Mixed into vLLM's Worker class via ``worker_extension_cls``. Methods + are called from the VLLMSampler Ray actor through + ``AsyncLLM.collective_rpc()``. Usage: - When creating VLLMEngine, pass: worker_extension_cls="twinkle.sampler.vllm_worker_extension.TwinkleWorkerExtension" """ - - def update_weights_from_tensors( - self, - weights: List[Tuple[str, torch.Tensor]], - ) -> int: - # do we need searialization for tensors? - # do we need bucket loading? 
- model = self.model_runner.model - - try: - # Call model's load_weights directly - loaded_params = model.load_weights(weights) - logger.info(f"Loaded {len(loaded_params)} weight tensors directly") - return len(loaded_params) - except Exception as e: - logger.error(f"Failed to load weights: {e}") - raise + + def __new__(cls, *args, **kwargs): + from twinkle.patch.vllm_lora_weights import VLLMLoraWeights + set_death_signal() + VLLMLoraWeights()(None) + return super().__new__(cls) + + # ----------------------------------------------------------------- + # Public API — called via collective_rpc from VLLMEngine + # ----------------------------------------------------------------- def update_weights_from_ipc( self, @@ -88,29 +98,26 @@ def update_weights_from_ipc( base_sync_done: bool = False, use_shm: bool = False, ) -> None: - """ - Update weights via CUDA IPC or shared memory. - - This method receives weights from training process via ZMQ + IPC. - Only works in colocate mode (same machine). - + """Receive and load weights via ZMQ + CUDA IPC/SHM. + + Called via ``collective_rpc("update_weights_from_ipc", ...)`` from + :meth:`VLLMEngine.update_weights`. The VLLMEngine sends weights + in buckets over a ZMQ REQ/REP channel backed by CUDA IPC (GPU + tensors) or shared memory (CPU tensors). + Args: - peft_config: If provided, loads as LoRA adapter. - base_sync_done: If True and peft_config provided, replaces existing LoRA. + peft_config: If provided with base_sync_done, loads as LoRA. + base_sync_done: If True and peft_config, replaces existing LoRA. use_shm: If True, use shared memory instead of CUDA IPC. """ import zmq - - # Get device info + if self.device is None: - # ascend vllm does not set device, set here self.device = torch.device(Torch.get_device()) - - # remove existing LoRA if present + if peft_config and base_sync_done: self.remove_lora(VLLM_LORA_INT_ID) - assert self.device is not None # Setup ZMQ socket if not hasattr(self, '_zmq_ctx') or self._zmq_ctx is None: self._zmq_ctx = zmq.Context() @@ -118,46 +125,42 @@ def update_weights_from_ipc( socket.connect(self._get_zmq_handle()) comm_metadata = socket.recv_pyobj() + buffer, shm = None, None if not use_shm: handle = comm_metadata - buffer = rebuild_ipc(handle, self.device.index) - assert buffer.dtype == torch.uint8 + buffer = _rebuild_ipc(handle, self.device.index) else: - shm_name = comm_metadata["name"] - shm_size = comm_metadata["size"] - buffer, shm = rebuild_shared_memory(shm_name, shm_size, dtype=torch.uint8) - - socket.send(b"") # Ready to receive - - # Receive and load weights in buckets + from multiprocessing import shared_memory + buffer, shm = _rebuild_shared_memory( + comm_metadata["name"], comm_metadata["size"], + ) + + socket.send(b"") # Ready + while True: metadata = socket.recv_pyobj() weights = [] - + for name, meta in metadata["bucket_meta"].items(): shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"] size = dtype.itemsize * shape.numel() tensor = buffer[offset:offset + size].view(dtype=dtype).view(shape) - if not use_shm: - # CUDA IPC: clone to release IPC memory tensor = tensor.clone() else: tensor = tensor.to(self.device) weights.append((name, tensor)) - + Torch.synchronize() socket.send(b"") - - # Load weights - self._update_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) + + self._load_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) del weights - + if metadata["is_last"]: break - - # Cleanup + socket.close() del buffer if shm is not None: @@ 
-166,160 +169,116 @@ def update_weights_from_ipc( gc.collect() Torch.ipc_collect() Torch.empty_cache() - - def _ensure_lora_patch_applied(self): - """Ensure VLLMLoraWeights patch is applied for tensor-based LoRA loading.""" - if getattr(self, '_lora_patch_applied', False): - return - - # Apply patch directly to LRUCacheWorkerLoRAManager._load_adapter - # This is a simplified version that doesn't need sampler reference - # The full VLLMLoraWeights.patch() needs sampler for tokenizer, but - # for _load_adapter patch we only need the core tensor loading logic - - from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager - try: - from vllm.lora.models import LoRAModel - except ImportError: - from vllm.lora.lora_model import LoRAModel - from vllm.lora.utils import get_adapter_absolute_path - from vllm.lora.peft_helper import PEFTHelper - from twinkle.patch.vllm_lora_weights import TensorLoRARequest - - def patched_load_adapter(manager: LRUCacheWorkerLoRAManager, lora_request) -> LoRAModel: - """Load LoRA adapter, supporting tensor-based loading for TensorLoRARequest.""" - supported_lora_modules = manager._adapter_manager.supported_lora_modules - packed_modules_mapping = manager._adapter_manager.packed_modules_mapping - expected_lora_modules: list[str] = [] - for module in supported_lora_modules: - if module in packed_modules_mapping: - expected_lora_modules.extend(packed_modules_mapping[module]) - else: - expected_lora_modules.append(module) - expected_lora_modules = list(set(expected_lora_modules)) - - lora_tensors = None - if isinstance(lora_request, TensorLoRARequest): - peft_config = lora_request.peft_config - lora_tensors = lora_request.lora_tensors - peft_helper = PEFTHelper.from_dict(peft_config) - else: - lora_path = get_adapter_absolute_path(lora_request.lora_path) - peft_helper = PEFTHelper.from_local_dir(lora_path, manager.max_position_embeddings) - - peft_helper.validate_legal(manager.lora_config) - model = manager._adapter_manager.model - hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None) - - if isinstance(lora_request, TensorLoRARequest): - lora = manager._lora_model_cls.from_lora_tensors( - lora_model_id=lora_request.lora_int_id, - tensors=lora_tensors, - peft_helper=peft_helper, - device='cpu', - dtype=manager.lora_config.lora_dtype, - embeddings=None, - target_embedding_padding=manager.vocab_size + manager.lora_config.lora_extra_vocab_size, - embedding_modules=manager.embedding_modules, - embedding_padding_modules=manager.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper, - ) - else: - lora = manager._lora_model_cls.from_local_checkpoint( - lora_path, - expected_lora_modules, - peft_helper=peft_helper, - lora_model_id=lora_request.lora_int_id, - device='cpu', - dtype=manager.lora_config.lora_dtype, - target_embedding_padding=manager.vocab_size + manager.lora_config.lora_extra_vocab_size, - embedding_modules=manager.embedding_modules, - embedding_padding_modules=manager.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper, - ) - - if lora.extra_vocab_size > manager.lora_config.lora_extra_vocab_size: - raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than ' - f'lora_extra_vocab_size {manager.lora_config.lora_extra_vocab_size}.') - return lora - - if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): - LRUCacheWorkerLoRAManager._old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter - LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter - - self._lora_patch_applied = True - logger.info("LoRA 
tensor loading patch applied to LRUCacheWorkerLoRAManager")
-
-    def _convert_peft_to_vllm_lora_name(self, name: str) -> str:
+
+    def load_synced_weights(
+        self,
+        weights: Dict[str, torch.Tensor],
+        peft_config: Optional[Dict] = None,
+        base_sync_done: bool = False,
+    ) -> None:
+        """Load weights received from the checkpoint engine.
+
+        Called via ``collective_rpc("load_synced_weights", kwargs=...)``
+        from :meth:`VLLMEngine.update_weights`.
+
+        Two modes:
+        - **Base model** (``base_sync_done=False``):
+          Strips PEFT prefixes and loads via ``model.load_weights()``.
+        - **LoRA adapter** (``base_sync_done=True`` + ``peft_config``):
+          Converts names to vLLM LoRA format and loads via ``add_lora()``.
+
+        Args:
+            weights: Dict mapping weight names to tensors.
+            peft_config: PEFT config dict for LoRA adapter loading.
+            base_sync_done: If True with peft_config, load as LoRA adapter.
+        """
+        if self.device is None:
+            self.device = torch.device(Torch.get_device())
+
+        weight_list = list(weights.items())
+        self._load_weights(weight_list, peft_config=peft_config, base_sync_done=base_sync_done)
+
+        gc.collect()
+        Torch.empty_cache()
+
+    # -----------------------------------------------------------------
+    # Internal helpers
+    # -----------------------------------------------------------------
+
+    @staticmethod
+    def _convert_peft_to_vllm_lora_name(name: str) -> str:
         """Convert PEFT LoRA weight name to vLLM format.
-
-        PEFT format: base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
-        vLLM format: base_model.model.layers.0.self_attn.q_proj.lora_A.weight
-
-        Transformations:
-        1. base_model.model.model.X -> base_model.model.X (remove extra model.)
-        2. lora_A.default.weight -> lora_A.weight (remove .default)
+
+        PEFT: base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
+        vLLM: base_model.model.layers.0.self_attn.q_proj.lora_A.weight
         """
-        # Remove extra 'model.' prefix if present
         if name.startswith('base_model.model.model.'):
             name = 'base_model.model.' + name[len('base_model.model.model.'):]
-
-        # Remove '.default' from LoRA weight names
-        # e.g., lora_A.default.weight -> lora_A.weight
-        name = name.replace('.lora_A.default.', '.lora_A.')
-        name = name.replace('.lora_B.default.', '.lora_B.')
-
+        name = re.sub(r'\.lora_A\.[^.]+\.', '.lora_A.', name)
+        name = re.sub(r'\.lora_B\.[^.]+\.', '.lora_B.', name)
         return name
-
-    def _update_weights(
+
+    def _load_weights(
         self,
         weights: List[Tuple[str, torch.Tensor]],
         peft_config: Optional[Dict],
         base_sync_done: bool,
     ) -> None:
-        """Load a batch of weights."""
+        """Load a batch of weights into vLLM.
+
+        Two modes:
+        - LoRA mode (``peft_config`` and ``base_sync_done``): Loads weights as
+          a tensor-based LoRA adapter via ``add_lora()``.
+        - Base model mode: Strips PEFT prefixes and delegates to vLLM's
+          ``model.load_weights()``, which merges split weights (q/k/v_proj ->
+          qkv_proj, gate/up_proj -> gate_up_proj) into its stacked format
+          internally.
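+
+        Example name mapping in base mode (assuming vLLM was started without
+        ``enable_lora``):
+            base_model.model.model.layers.0.mlp.gate_proj.base_layer.weight
+                -> model.layers.0.mlp.gate_proj.weight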
+ """ if peft_config and base_sync_done: - # LoRA mode - need patch for tensor-based loading - self._ensure_lora_patch_applied() - + # Remove existing LoRA before replacing + self.remove_lora(VLLM_LORA_INT_ID) + from twinkle.patch.vllm_lora_weights import TensorLoRARequest - - # Convert PEFT weight names to vLLM format - converted_weights = {} - for name, tensor in weights: - vllm_name = self._convert_peft_to_vllm_lora_name(name) - converted_weights[vllm_name] = tensor - + + converted = { + self._convert_peft_to_vllm_lora_name(n): t + for n, t in weights + } lora_request = TensorLoRARequest( lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH, peft_config=peft_config, - lora_tensors=converted_weights, + lora_tensors=converted, ) self.add_lora(lora_request) else: - # Strip PEFT prefix from weight names if present - # PEFT uses 'base_model.model.model.' prefix while vLLM expects 'model.' - # Also filter out LoRA-specific weights (lora_A, lora_B) as they should - # be handled separately in LoRA mode - converted_weights = [] + # Base model mode — strip PEFT prefixes and delegate to + # vLLM's model.load_weights() which handles stacked params, + # prefix normalization, and weight_loader internally. + vllm_has_lora = getattr( + getattr(self, 'vllm_config', None), 'lora_config', None, + ) is not None + + converted = [] for name, tensor in weights: - # Skip LoRA-specific weights for base model sync if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: continue - # Remove PEFT wrapper prefixes - if name.startswith('base_model.model.model.'): - name = 'model.' + name[len('base_model.model.model.'):] - elif name.startswith('base_model.model.'): - name = name[len('base_model.model.'):] - converted_weights.append((name, tensor)) - # TODO: FP8 support - if converted_weights: - self.model_runner.model.load_weights(converted_weights) - + name = name.removeprefix('model.base_model.model.') + name = name.removeprefix('base_model.model.') + if not vllm_has_lora: + name = name.replace('.base_layer.', '.') + converted.append((name, tensor)) + + if not converted: + return + + self.model_runner.model.load_weights(converted) + logger.info(f"Loaded {len(converted)} base weights") + def _get_zmq_handle(self) -> str: """Get ZMQ handle for IPC communication.""" - if not hasattr(self, 'device_uuid') or not self.device_uuid: - self.device_uuid = get_device_uuid(self.device.index) - return f"ipc:///tmp/twinkle-ipc-{self.device_uuid}.sock" + if not hasattr(self, '_device_uuid') or not self._device_uuid: + self._device_uuid = _get_device_uuid(self.device.index) + return f"ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock" diff --git a/src/twinkle/utils/network.py b/src/twinkle/utils/network.py index f11be938..83d686fb 100644 --- a/src/twinkle/utils/network.py +++ b/src/twinkle/utils/network.py @@ -1,7 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import socket +from datetime import timedelta from typing import Optional +import torch + + +def is_valid_ipv6_address(ip: str) -> bool: + """Check if the given string is a valid IPv6 address.""" + try: + socket.inet_pton(socket.AF_INET6, ip) + return True + except socket.error: + return False + def find_node_ip() -> Optional[str]: import psutil @@ -19,15 +31,106 @@ def find_node_ip() -> Optional[str]: return main_ip or virtual_ip -def find_free_port(start_port: Optional[int] = None, retry: int = 100) -> int: +def find_free_port(address: str = '', start_port: Optional[int] = None, retry: int = 100) -> int: + family = socket.AF_INET + if address and is_valid_ipv6_address(address): + family = socket.AF_INET6 if start_port is None: start_port = 0 for port in range(start_port, start_port + retry): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + with socket.socket(family, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) try: sock.bind(('', port)) port = sock.getsockname()[1] break except OSError: pass - return port \ No newline at end of file + return port + + + +def stateless_init_process_group( + master_address: str, + master_port: int, + rank: int, + world_size: int, + device: int | torch.device = None, + backend: str = "nccl", + listen_socket: socket.socket = None, + listen_fd: int = None, +): + """Create a stateless process group using vLLM's StatelessProcessGroup. + + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL/HCCL) between external (train processes) + and vLLM workers. + + Args: + master_address: The IP address of the master (rank 0). + master_port: The port of the master. + rank: The rank of this process. + world_size: Total number of processes. + device: The CUDA device to use. If None, uses current device. + backend: The communication backend ("nccl" or "hccl"). + listen_socket: Optional pre-created listening socket for master (rank 0). + If provided, this socket will be reused instead of creating a new one. + listen_fd: Optional file descriptor of the listening socket. + + Returns: + PyNcclCommunicator or PyHcclCommunicator instance. 
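+
+    Example (a sketch; the address, port, and tensor are assumptions, and
+    each process makes its own call with its own rank):
+        >>> # rank 0 on the trainer, rank 1 on the vLLM worker
+        >>> comm = stateless_init_process_group('10.0.0.1', 29500, rank=0, world_size=2)
+        >>> comm.broadcast(weight_tensor, src=0)  # data-plane sync outside any global group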
+ """ + from torch.distributed import TCPStore + from vllm.distributed.utils import StatelessProcessGroup + + if backend == "hccl": + from vllm_ascend.distributed.device_communicators.pyhccl import ( + PyHcclCommunicator as Communicator, + ) + else: + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator as Communicator, + ) + + if device is None: + device = torch.cuda.current_device() if backend == "nccl" else torch.npu.current_device() + + # Create the stateless process group + launch_server = rank == 0 + + if launch_server and listen_socket is None: + # For master, create a listening socket if not provided + if is_valid_ipv6_address(master_address): + listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + else: + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((master_address, master_port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + elif launch_server and listen_fd is None: + listen_fd = listen_socket.fileno() + + store = TCPStore( + host_name=master_address, + port=master_port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=300), + use_libuv=False, # for compatibility + master_listen_fd=listen_fd, + ) + + pg = StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + socket=listen_socket, + data_expiration_seconds=3600, + ) + + communicator = Communicator(pg, device=device) + return communicator diff --git a/src/twinkle/weight_loader/ipc_loader.py b/src/twinkle/weight_loader/ipc_loader.py index a1eb8db9..19d502d0 100644 --- a/src/twinkle/weight_loader/ipc_loader.py +++ b/src/twinkle/weight_loader/ipc_loader.py @@ -1,39 +1,23 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""CUDA IPC Weight Loader for Hybrid mode. +"""Weight Loader for Hybrid mode. -This loader synchronizes weights from training model to vLLM sampler -using CUDA IPC for efficient GPU-to-GPU transfer within the same machine. +Synchronizes weights from training model to vLLM sampler in Hybrid +deployment, where model and sampler live in the same Ray Worker process +but vLLM runs in a subprocess. -Architecture: - Training Model (main process) - │ - │ get_hf_state_dict() -> Generator[(name, tensor)] - ▼ - IPCWeightLoader - │ - │ ZMQ + CUDA IPC (bucket-based) - ▼ - vLLM Worker (subprocess) - │ - └── TwinkleWorkerExtension.update_weights_from_ipc() - -Supported Modes: - - HYBRID: Model and Sampler in same process, vLLM in subprocess (same GPU) - - COLOCATE: Model and Sampler in different processes (same GPU) - needs ray handle support +The weights are collected from the model and passed to vLLM via +``VLLMEngine.update_weights()`` → ``collective_rpc`` → +``TwinkleWorkerExtension.load_synced_weights()`` in the worker subprocess. Limitations: - - CUDA IPC only works on the same physical GPU - - For STANDALONE mode (different GPUs), use NCCLWeightLoader instead + - For STANDALONE mode (different GPUs), use CheckpointEngineManager instead. 
""" -import concurrent.futures +import asyncio import gc import logging -import os -import uuid -from typing import Any, Dict, Generator, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Tuple import torch -from torch.multiprocessing.reductions import reduce_tensor from twinkle.model.base import TwinkleModel from twinkle.sampler.base import Sampler @@ -43,352 +27,100 @@ logger = logging.getLogger(__name__) -def get_device_uuid(device_id: int) -> str: - """Get unique device identifier.""" - try: - from vllm.platforms import current_platform - # For NPU, handle ASCEND_RT_VISIBLE_DEVICES - if Torch.is_npu_available(): - npu_visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "").split(",") - if device_id < len(npu_visible_devices): - return "NPU-" + npu_visible_devices[device_id] - return current_platform.get_device_uuid(device_id) - except Exception: - # Fallback to random UUID if vllm not available - return uuid.uuid4().hex[:16] - - -def is_ipc_supported() -> bool: - """Check if CUDA/NPU IPC is supported.""" - if Torch.is_gpu_available(): - return True - if Torch.is_npu_available(): - # NPU IPC requires specific versions - # Ascend HDK >= 25.3.rc1 and CANN >= 8.3.RC1 - # TODO: Add version check - return True - return False +class IPCWeightLoader(WeightLoader): + """Weight loader for Hybrid mode. + Collects model weights and transfers them to the vLLM subprocess via + ``VLLMEngine.update_weights()`` (which uses ``collective_rpc``). -class IPCWeightLoader(WeightLoader): - """Weight loader using CUDA IPC for Hybrid mode. - - This loader is designed for scenarios where training model and sampler - are in the same Ray Worker but vLLM runs in subprocess (spawn mode). - - Features: - - CUDA IPC for zero-copy GPU memory sharing - - Bucket-based streaming transfer (avoids OOM for large models) - - Fallback to shared memory when CUDA IPC not supported - Args: - model: Training model instance (TransformersModel/MegatronModel) - sampler: Sampler instance (VLLMSampler) - bucket_size_mb: Size of transfer bucket in MB (default: 512) - use_shm: Force use shared memory instead of CUDA IPC - dtype: Target dtype for weights (default: bfloat16) - - Note: - - For Hybrid mode, model and sampler must be actual objects (not Ray handles) - - For Colocate mode with Ray handles, additional handling is needed - - This loader only supports same-GPU scenarios (CUDA IPC limitation) - - For cross-GPU scenarios (STANDALONE mode), use NCCLWeightLoader - + model: Training model instance (TransformersModel/MegatronModel). + sampler: Sampler instance (VLLMSampler). + dtype: Target dtype for weights (default: bfloat16). + Example: >>> model = TransformersModel(model_id="Qwen/Qwen2.5-0.5B") - >>> sampler = VLLMSampler(model_id="Qwen/Qwen2.5-0.5B", + >>> sampler = VLLMSampler(model_id="Qwen/Qwen2.5-0.5B", ... engine_args={'load_format': 'dummy'}) >>> loader = IPCWeightLoader(model, sampler) - >>> loader.load_weights() # Sync model weights to sampler + >>> loader.load_weights() """ - + def __init__( self, model: TwinkleModel, sampler: Sampler, - bucket_size_mb: int = 512, - use_shm: bool = False, dtype: torch.dtype = torch.bfloat16, + **kwargs, ): self.model = model self.sampler = sampler - self.bucket_size = bucket_size_mb << 20 # Convert to bytes - self.use_shm = use_shm or not is_ipc_supported() self.dtype = dtype - - self._zmq_ctx = None - self._device_uuid = None self.base_sync_done = False - if self.use_shm: - logger.warning( - "IPC is not supported on your devices. 
Falling back to shared memory for weight transfer, " - "which may cause performance degradation." - ) - - @property - def device_uuid(self) -> str: - """Get or compute device UUID.""" - if self._device_uuid is None: - device_id = Torch.get_current_device() - if isinstance(device_id, str): - device_id = 0 - self._device_uuid = get_device_uuid(device_id) - return self._device_uuid - - @property - def zmq_handle(self) -> str: - """Get ZMQ IPC socket address.""" - return f"ipc:///tmp/twinkle-ipc-{self.device_uuid}.sock" - + def load_weights(self, adapter_name: str = '', peft_config: Optional[Dict] = None): - """Sync weights from model to sampler via CUDA IPC. - - This is the main entry point for weight synchronization. - + """Sync weights from model to sampler. + Args: - adapter_name: Name of the adapter (for LoRA) - peft_config: PEFT config for LoRA mode. When provided with base_sync_done=True, - only LoRA weights are synced (assuming base model is already loaded in vLLM). + adapter_name: Name of the adapter (for LoRA, reserved). + peft_config: PEFT config dict for LoRA adapter loading. """ - import zmq - - # Get weights iterator from training model - # For TransformersModel: returns dict of {name: tensor} - # For MegatronModel: get_hf_state_dict() returns Generator[(name, tensor)] - # Using iterator directly avoids OOM for large models - weights_source = self._get_weights_iterator(adapter_name) - - logger.info("Starting CUDA IPC weight sync...") - - # Step 1: Setup ZMQ sender FIRST (bind before worker connects) - if self._zmq_ctx is None: - self._zmq_ctx = zmq.Context() - socket = self._zmq_ctx.socket(zmq.REQ) - socket.bind(self.zmq_handle) - logger.debug(f"ZMQ socket bound to {self.zmq_handle}") - - # Step 2: Trigger vLLM worker to start receiving (non-blocking) - # Worker will connect to ZMQ and wait for data - receiver_future = self._trigger_receiver(peft_config, base_sync_done=self.base_sync_done) - - # Give worker time to connect import time - time.sleep(0.5) - - try: - # Step 3: Create transfer buffer and send handle - buffer, shm = self._create_buffer(socket) - - # Step 4: Send weights in buckets (streaming, no full list) - count = self._send_weights_in_buckets(socket, buffer, weights_source) - - # Step 5: Wait for receiver to complete processing - # This ensures the collective_rpc is fully done before returning - try: - receiver_future.result(timeout=60) # 60s timeout - except Exception as e: - logger.warning(f"Receiver future completed with: {e}") - - # Step 6: Clear KV cache after weight update - # This is necessary because the model weights have changed - self._clear_kv_cache() - - logger.info(f"CUDA IPC weight sync completed: {count} tensors") - - finally: - # Cleanup - socket.close() - if buffer is not None: - del buffer - if shm is not None: - shm.close() - shm.unlink() - gc.collect() - Torch.ipc_collect() - Torch.empty_cache() - - def _get_weights_iterator(self, adapter_name: str = '') -> Iterable[Tuple[str, torch.Tensor]]: - """Get weights iterator from model. - - This method handles both TransformersModel (dict) and MegatronModel (generator). - For MegatronModel, it uses get_hf_state_dict() which returns a generator - that converts Megatron format to HF format on-the-fly. 
- - Args: - adapter_name: Name of the adapter (for LoRA) - - Returns: - Iterable of (name, tensor) pairs - """ - # Check if model has get_hf_state_dict (MegatronModel) - if hasattr(self.model, 'get_hf_state_dict'): - # MegatronModel: returns generator that converts to HF format - return self.model.get_hf_state_dict(adapter_name=adapter_name) - else: - # TransformersModel: returns dict - state_dict = self.model.get_state_dict(adapter_name=adapter_name) - if isinstance(state_dict, dict): - return state_dict.items() - else: - # Already an iterator/generator - return state_dict - - def _trigger_receiver(self, peft_config: Optional[Dict], base_sync_done: bool = False) -> "concurrent.futures.Future": - """Trigger vLLM worker to start receiving weights. - - This calls the sampler's engine to invoke collective_rpc on all - vLLM workers to start the update_weights_from_ipc method. - - IMPORTANT: This method triggers the receiver in a non-blocking way. - The collective_rpc call starts the worker method but doesn't wait for - it to complete, allowing the sender to start sending data. - - Args: - peft_config: PEFT config for LoRA mode. - base_sync_done: If True and peft_config provided, only sync LoRA weights. - If False and peft_config provided, sync base model weights first. - - Returns: - Future that can be awaited after sending weights. - """ - import asyncio - - # Access vLLM engine's collective_rpc through sampler - # VLLMSampler -> VLLMEngine -> AsyncLLM -> collective_rpc + + start_time = time.time() + + # Collect weights from training model + weights = {} + for name, tensor in self._get_weights_iterator(adapter_name): + tensor = Torch.to_local_tensor(tensor) + weights[name] = tensor.to(self.dtype, non_blocking=True) + Torch.synchronize() + + # Transfer to vLLM subprocess via collective_rpc engine = self.sampler.engine - - # Run the async trigger in the sampler's event loop - async def _trigger(): - # collective_rpc call - worker will start receiving and wait on ZMQ socket - return await engine.engine.collective_rpc( - "update_weights_from_ipc", - kwargs={ - "peft_config": peft_config, - "use_shm": self.use_shm, - "base_sync_done": base_sync_done, - }, - ) - - # Schedule the task without waiting for completion - # This allows the sender to proceed immediately future = asyncio.run_coroutine_threadsafe( - _trigger(), - self.sampler._async_loop + engine.update_weights( + weights, + peft_config=peft_config, + base_sync_done=self.base_sync_done, + ), + self.sampler._async_loop, ) - # Return future so caller can wait after sending weights - return future - - def _create_buffer(self, socket) -> Tuple[torch.Tensor, Any]: - """Create transfer buffer and send handle to receiver.""" - buffer = None - shm = None - - if not self.use_shm: - # CUDA IPC mode - device = Torch.get_device(None) - buffer = torch.empty(self.bucket_size, dtype=torch.uint8, device=device) - handle = reduce_tensor(buffer) - socket.send_pyobj(handle) - else: - # Shared memory mode - from multiprocessing import shared_memory - - shm_name = f"twinkle_weights_{uuid.uuid4().hex}" - shm = shared_memory.SharedMemory(name=shm_name, create=True, size=self.bucket_size) - buffer = torch.frombuffer(shm.buf, dtype=torch.uint8) - socket.send_pyobj({"name": shm_name, "size": self.bucket_size}) - - socket.recv() # Wait for receiver ready - return buffer, shm - - def _send_weights_in_buckets( - self, - socket, - buffer: torch.Tensor, - weights: Iterable[Tuple[str, torch.Tensor]], - ) -> int: - """Send weights in buckets via streaming. 
-
-        This method processes weights one by one without loading all into memory,
-        which is critical for large models to avoid OOM.
-
-        Args:
-            socket: ZMQ socket for communication
-            buffer: CUDA IPC buffer for weight transfer
-            weights: Iterable of (name, tensor) pairs
-
-        Returns:
-            Number of tensors sent
+        future.result(timeout=120)
+
+        # Clear KV cache since model weights changed
+        self._clear_kv_cache()
+
+        del weights
+        gc.collect()
+        Torch.empty_cache()
+
+        elapsed = time.time() - start_time
+        logger.info(f"Weight sync completed in {elapsed:.2f}s")
+
+    def _get_weights_iterator(self, adapter_name: str = '') -> Iterable[Tuple[str, torch.Tensor]]:
+        """Get weights iterator from the local model object.
+
+        Supports TransformersModel (state_dict().items()) and
+        MegatronModel (get_hf_state_dict → generator).
         """
-        offset = 0
-        bucket_meta = {}
-        count = 0
-
-        for name, weight in weights:
-            # Convert DTensor to local tensor if needed
-            weight = Torch.to_local_tensor(weight)
-
-            # Convert to target dtype
-            weight = weight.to(self.dtype, non_blocking=True)
-
-            # Check if bucket is full
-            if offset + weight.nbytes > self.bucket_size:
-                Torch.synchronize()
-                socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False})
-                socket.recv()
-                bucket_meta = {}
-                offset = 0
-
-            # Validate weight fits in bucket
-            if weight.nbytes > self.bucket_size:
-                raise ValueError(
-                    f"Weight '{name}' ({weight.shape}, {weight.dtype}) is too large "
-                    f"({weight.nbytes / 1e6:.1f}MB) for bucket ({self.bucket_size / 1e6:.1f}MB). "
-                    f"Increase bucket_size_mb."
-                )
-
-            # Add weight to bucket
-            bucket_meta[name] = {
-                "name": name,
-                "shape": weight.shape,
-                "dtype": weight.dtype,
-                "offset": offset,
-            }
-            buffer[offset:offset + weight.nbytes].copy_(
-                weight.view(-1).view(torch.uint8), non_blocking=True
-            )
-            offset += weight.nbytes
-            count += 1
-
-        # Send final bucket
-        Torch.synchronize()
-        socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True})
-        socket.recv()
-
-        return count
-
+        if hasattr(self.model, 'get_hf_state_dict'):
+            return self.model.get_hf_state_dict()
+        else:
+            # state_dict() returns a mapping; return (name, tensor) pairs so
+            # the caller can iterate uniformly over both model types.
+            return self.model.state_dict().items()
+
     def _clear_kv_cache(self) -> None:
-        """Clear KV cache after weight update.
-
-        This is necessary because the model weights have changed and
-        any cached KV pairs are now invalid.
-        """
-        import asyncio
-
+        """Clear KV cache after weight update."""
         engine = self.sampler.engine
-
-        async def _clear():
-            await engine.clear_kv_cache()
-
         try:
             future = asyncio.run_coroutine_threadsafe(
-                _clear(),
-                self.sampler._async_loop
+                engine.clear_kv_cache(),
+                self.sampler._async_loop,
             )
             future.result(timeout=10)
         except Exception as e:
             logger.warning(f"Failed to clear KV cache: {e}")
-
+
     def __call__(
         self,
         model: TwinkleModel = None,
diff --git a/tests/sampler/test_megatron_weight_sync.py b/tests/sampler/test_megatron_weight_sync.py
new file mode 100644
index 00000000..403c57be
--- /dev/null
+++ b/tests/sampler/test_megatron_weight_sync.py
@@ -0,0 +1,290 @@
+#!/usr/bin/env python
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Test STANDALONE weight synchronization between MegatronModel and vLLM sampler.
+
+This script tests the checkpoint engine weight sync flow when the training
+model uses Megatron-Core (with TP/PP parallelism) and the inference sampler
+uses vLLM:
+
+    1. Create MegatronModel (with real weights, TP=2) and VLLMSampler (with dummy weights)
+    2. Sample with dummy weights → garbage output
+    3. Sync weights from MegatronModel → VLLMSampler via CheckpointEngineManager
+    4. 
Sample with synced weights → coherent output + 5. Verify that outputs differ (proof that weights were synced) + +The Megatron bridge internally handles TP allgather during export, converting +Megatron-format weights to HuggingFace format on-the-fly. + +Usage: + # 2 Megatron GPUs (TP=2) + 2 sampler GPUs (4 GPUs total, using GPUs 4-7) + CUDA_VISIBLE_DEVICES=4,5,6,7 python tests/sampler/test_megatron_weight_sync.py + + # 2 Megatron GPUs (TP=2) + 1 sampler GPU (3 GPUs total) + CUDA_VISIBLE_DEVICES=4,5,6 python tests/sampler/test_megatron_weight_sync.py --sampler-gpus 1 + + # Custom model + CUDA_VISIBLE_DEVICES=4,5,6,7 TEST_MODEL_ID=Qwen/Qwen2.5-7B-Instruct \ + python tests/sampler/test_megatron_weight_sync.py --tp-size 2 +""" + +import os +import sys +import time +import argparse +import logging + +# Must set before importing anything +os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' +os.environ['VLLM_LOGGING_LEVEL'] = 'WARNING' +# Prevent hanging during NCCL weight sync in disaggregated mode +os.environ['NCCL_CUMEM_ENABLE'] = '0' + +# Model configuration — use a small model for testing +MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen2.5-0.5B-Instruct') + +logger = logging.getLogger(__name__) + + +def log(msg): + """Print message with timestamp.""" + import datetime + ts = datetime.datetime.now().strftime("%H:%M:%S") + print(f"[{ts}] {msg}", flush=True) + + +def wait_result(result): + """Resolve lazy collect / ray object ref to actual value.""" + if hasattr(result, '_is_lazy_collect') and result._is_lazy_collect: + return result() + if hasattr(result, 'wait'): + return result.wait() + if callable(result) and hasattr(result, '_get_result'): + return result() + return result + + +def get_model_path(): + """Resolve model_id to a local cache path (for offline environments).""" + try: + from modelscope.hub.snapshot_download import snapshot_download + _cache = snapshot_download(MODEL_ID, local_files_only=True) + if _cache: + return _cache + except Exception: + pass + return MODEL_ID + + +# ============================================================================= +# Test: Megatron Standalone Weight Sync +# ============================================================================= + +def test_megatron_weight_sync( + model_gpus: int = 2, + sampler_gpus: int = 2, + tp_size: int = 2, + pp_size: int = 1, +): + """Test weight sync from MegatronModel to VLLMSampler via NCCL broadcast. + + Architecture: + Model workers : GPU 0 .. model_gpus-1 (Megatron, TP=tp_size, real weights) + Sampler workers: GPU model_gpus .. total-1 (vLLM, dummy weights) + + The Megatron bridge converts weights from Megatron format to HuggingFace + format during export. TP allgather is handled internally by the bridge. + Only model_actor[0] broadcasts via the checkpoint engine's NCCL group; + other model actors consume the generator (triggering TP allgather) but + do not participate in the broadcast. 
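+
+    Roughly, the export side looks like this (an illustrative sketch only;
+    the real bucketing and process-group setup live in CheckpointEngineManager,
+    and is_source / sync_group are placeholder names):
+
+        for name, tensor in model.get_hf_state_dict():  # TP allgather fires here
+            if is_source:  # True only on model_actor[0]
+                torch.distributed.broadcast(tensor, src=0, group=sync_group)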
+ """ + import twinkle + from twinkle import DeviceGroup, DeviceMesh + from twinkle.model import MegatronModel + from twinkle.sampler import VLLMSampler + from twinkle.template import Template + from twinkle.checkpoint_engine import CheckpointEngineManager + from twinkle.data_format import Trajectory + from twinkle.sampler.types import SamplingParams + + total_gpus = model_gpus + sampler_gpus + model_path = get_model_path() + + # Validate parallelism config + assert model_gpus == tp_size * pp_size, ( + f"model_gpus ({model_gpus}) must equal tp_size * pp_size " + f"({tp_size} * {pp_size} = {tp_size * pp_size})" + ) + + log("=" * 70) + log(f"TEST: Megatron Standalone Weight Sync") + log(f" Model : GPU 0-{model_gpus - 1} ({model_gpus} workers, TP={tp_size}, PP={pp_size})") + log(f" Sampler: GPU {model_gpus}-{total_gpus - 1} ({sampler_gpus} workers)") + log(f" Model : {model_path}") + log("=" * 70) + + # ── Initialize Twinkle in Ray mode ──────────────────────────────── + twinkle.initialize( + mode='ray', + nproc_per_node=total_gpus, + groups=[ + DeviceGroup( + name='model', + ranks=list(range(model_gpus)), + device_type='GPU', + gpus_per_worker=1, + ), + DeviceGroup( + name='sampler', + ranks=list(range(model_gpus, total_gpus)), + device_type='GPU', + gpus_per_worker=1, + ), + ], + ) + + try: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + except Exception: + from modelscope import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # ── Create MegatronModel (real weights) ──────────────────────────── + log("\nCreating MegatronModel (real weights)...") + model_device_mesh = DeviceMesh.from_sizes( + world_size=model_gpus, + dp_size=model_gpus // (tp_size * pp_size), + tp_size=tp_size, + pp_size=pp_size, + ) + model = MegatronModel( + model_id=model_path, + device_mesh=model_device_mesh, + mixed_precision='bf16', + sequence_parallel=(tp_size > 1), + remote_group='model', + ) + log(" MegatronModel created successfully") + + # ── Create Sampler (dummy weights) ──────────────────────────────── + log("Creating Sampler (dummy weights)...") + sampler = VLLMSampler( + model_id=model_path, + engine_args={ + 'load_format': 'dummy', + 'gpu_memory_utilization': 0.3, + 'max_model_len': 256, + 'enforce_eager': True, + 'enable_sleep_mode': True, + 'enable_lora': False, + }, + device_mesh=DeviceMesh.from_sizes(world_size=sampler_gpus, dp_size=sampler_gpus), + remote_group='sampler', + ) + sampler.set_template(Template, model_id=model_path) + log(" VLLMSampler created successfully") + + # Wait for vLLM initialization + log("Waiting for vLLM initialization...") + time.sleep(5) + + # ── Helper: sample one prompt ───────────────────────────────────── + def do_sample(prompt: str, max_tokens: int = 32) -> str: + traj = Trajectory(messages=[{'role': 'user', 'content': prompt}]) + response = wait_result( + sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0)) + ) + if response and response.sequences: + tokens = response.sequences[0].tokens + if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + return tokenizer.decode(tokens, skip_special_tokens=True) + return "" + + # ── Sample BEFORE sync (dummy weights → garbage) ────────────────── + log("\n--- Sampling BEFORE weight sync (dummy weights) ---") + text_before = do_sample("What is 2+2?") + log(f" Output: '{text_before[:100]}'") + + # ── Sync weights: MegatronModel → Sampler via NCCL ──────────────── + log("\n--- Syncing 
weights via CheckpointEngineManager ---") + manager = CheckpointEngineManager( + model=model, + sampler=sampler, + ) + + sync_start = time.time() + manager.sync_weights() + sampler.reset_prefix_cache() + sync_time = time.time() - sync_start + log(f" Weight sync completed in {sync_time:.2f}s") + + # ── Sample AFTER sync (real weights → coherent) ─────────────────── + log("\n--- Sampling AFTER weight sync (real weights) ---") + text_after = do_sample("What is 2+2?") + log(f" Output: '{text_after[:100]}'") + + # ── Verification ────────────────────────────────────────────────── + log("\n" + "=" * 70) + log("VERIFICATION") + log("=" * 70) + + outputs_differ = text_before != text_after + log(f" Outputs differ after sync: {outputs_differ}") + + if outputs_differ: + log(" PASS: Weight sync verified — outputs changed after sync.") + if "4" in text_after.lower() or "four" in text_after.lower(): + log(" BONUS: Model correctly answered '2+2' question!") + else: + log(" FAIL: Outputs are identical — weight sync may have failed.") + + sampler.shutdown() + return outputs_differ + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser(description='Test Megatron standalone weight synchronization') + parser.add_argument('--model-gpus', type=int, default=2, + help='Number of GPUs for Megatron model (default: 2)') + parser.add_argument('--sampler-gpus', type=int, default=2, + help='Number of GPUs for vLLM sampler (default: 2)') + parser.add_argument('--tp-size', type=int, default=2, + help='Tensor parallel size (default: 2)') + parser.add_argument('--pp-size', type=int, default=1, + help='Pipeline parallel size (default: 1)') + args = parser.parse_args() + + log(f"Starting Megatron standalone weight sync test...") + log(f" Model GPUs: {args.model_gpus}") + log(f" Sampler GPUs: {args.sampler_gpus}") + log(f" TP size: {args.tp_size}") + log(f" PP size: {args.pp_size}") + log(f" Model ID: {MODEL_ID}") + + try: + success = test_megatron_weight_sync( + model_gpus=args.model_gpus, + sampler_gpus=args.sampler_gpus, + tp_size=args.tp_size, + pp_size=args.pp_size, + ) + except Exception as e: + log(f"\nTest failed with exception: {e}") + import traceback + traceback.print_exc() + success = False + + log("\n" + "=" * 70) + log(f"RESULT: {'PASS' if success else 'FAIL'}") + log("=" * 70) + + return 0 if success else 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/sampler/test_weight_sync.py b/tests/sampler/test_weight_sync.py index 0e744b83..7dd540ce 100644 --- a/tests/sampler/test_weight_sync.py +++ b/tests/sampler/test_weight_sync.py @@ -1,47 +1,45 @@ #!/usr/bin/env python # Copyright (c) ModelScope Contributors. All rights reserved. -""" -Test weight synchronization between training model and vLLM sampler. +"""Test STANDALONE weight synchronization between training model and vLLM sampler. + +This script serves as both a test and a minimal demo of the weight sync flow +used during RL training: + + 1. Create TransformersModel (with real weights) and VLLMSampler (with dummy weights) + 2. Sample with dummy weights → garbage output + 3. Sync weights from Model → Sampler via CheckpointEngineManager (NCCL broadcast) + 4. Sample with synced weights → coherent output + 5. 
Verify that outputs differ (proof that weights were synced) -This test verifies that weights can be correctly synchronized from a -TransformersModel to a VLLMSampler using IPCWeightLoader in Hybrid mode. +Usage: + # 2 model GPUs + 2 sampler GPUs (requires 4 GPUs) + CUDA_VISIBLE_DEVICES=0,1,2,3 python tests/sampler/test_weight_sync.py --model-gpus 2 --sampler-gpus 2 -Test Configuration: -- 4 GPUs with DP=4, TP=1 -- Each GPU runs one HybridModelSamplerActor with model+sampler -- Weight sync via CUDA IPC within each GPU + # 1 model GPU + 1 sampler GPU (requires 2 GPUs) + CUDA_VISIBLE_DEVICES=0,1 python tests/sampler/test_weight_sync.py + +Note: + - Requires Ray and multiple GPUs + - Set TEST_MODEL_ID environment variable to use a different model """ + import os import sys import time +import argparse +import logging # Must set before importing anything os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_LOGGING_LEVEL'] = 'WARNING' -os.environ['CUDA_VISIBLE_DEVICES'] = '2,3' -# Model configuration - use a small model for testing -MODEL_ID = 'Qwen/Qwen2.5-0.5B-Instruct' -WORLD_SIZE = 2 - -from transformers import AutoTokenizer +# Prevent hanging during NCCL weight sync in disaggregated mode +# See: https://docs.vllm.ai/en/latest/usage/troubleshooting.html#known-issues +os.environ['NCCL_CUMEM_ENABLE'] = '0' -from twinkle import remote_class, remote_function, DeviceMesh, DeviceGroup -from twinkle.sampler import VLLMSampler -from twinkle.sampler.types import SamplingParams -from twinkle.template import Template -from twinkle.data_format import Trajectory -from twinkle.weight_loader import IPCWeightLoader -from twinkle.model.transformers import TransformersModel +# Model configuration — use a small model for testing +MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen2.5-0.5B-Instruct') - -# Resolve to local cache -try: - from modelscope.hub.snapshot_download import snapshot_download - _cache = snapshot_download(MODEL_ID, local_files_only=True) - if _cache: - MODEL_ID = _cache -except Exception: - pass +logger = logging.getLogger(__name__) def log(msg): @@ -52,7 +50,7 @@ def log(msg): def wait_result(result): - """Wait for result if it's a LazyCollect object.""" + """Resolve lazy collect / ray object ref to actual value.""" if hasattr(result, '_is_lazy_collect') and result._is_lazy_collect: return result() if hasattr(result, 'wait'): @@ -62,275 +60,193 @@ def wait_result(result): return result -@remote_class() -class HybridModelSamplerActor: - """Hybrid actor that fuses training model and sampler in same process. 
- - This simulates the Hybrid mode where: - - Training model (TransformersModel) holds the real weights - - vLLM Sampler starts with dummy/random weights - - Weight sync happens via IPCWeightLoader (CUDA IPC + ZMQ) - """ - - def __init__( - self, - model_id: str, - device_mesh: DeviceMesh = None, - remote_group: str = None, - **kwargs - ): - import torch - rank = torch.cuda.current_device() if torch.cuda.is_available() else 0 - log(f"[Rank {rank}] Initializing HybridModelSamplerActor...") - - # Initialize sampler with dummy weights (random initialization) - self.sampler = VLLMSampler( - model_id=model_id, - engine_args={ - 'load_format': 'dummy', # Random weights - 'gpu_memory_utilization': 0.3, - 'max_model_len': 256, - 'enforce_eager': True, - 'enable_sleep_mode': True, - }, - ) - self.sampler.set_template(Template, model_id=model_id) - log(f"[Rank {rank}] VLLMSampler initialized with dummy weights") - - # Initialize training model with real weights - self.model = TransformersModel(model_id=model_id, device_mesh=device_mesh) - log(f"[Rank {rank}] TransformersModel initialized with real weights") - - # Initialize weight loader for Hybrid mode (CUDA IPC) - self.weight_loader = IPCWeightLoader( - model=self.model, - sampler=self.sampler, - bucket_size_mb=512, - ) - log(f"[Rank {rank}] IPCWeightLoader initialized") - - @remote_function(dispatch='all', collect='first') - def sync_weights(self, adapter_name: str = ''): - """Sync weights from training model to sampler via CUDA IPC.""" - import torch - rank = torch.cuda.current_device() if torch.cuda.is_available() else 0 - log(f"[Rank {rank}] Starting weight sync...") - start = time.time() - self.weight_loader.load_weights(adapter_name=adapter_name) - elapsed = time.time() - start - log(f"[Rank {rank}] Weight sync completed in {elapsed:.2f}s") - return elapsed - - @remote_function(dispatch='all', collect='first') - def sample_text(self, prompt: str, max_tokens: int = 64) -> dict: - """Sample text from the model and return result info.""" - import torch - rank = torch.cuda.current_device() if torch.cuda.is_available() else 0 - - traj = Trajectory(messages=[{'role': 'user', 'content': prompt}]) - response = self.sampler.sample( - traj, - SamplingParams(max_tokens=max_tokens, temperature=0.0) - ) - - if response and hasattr(response, 'sequences') and response.sequences: - tokens = response.sequences[0].tokens - if hasattr(tokens, 'tolist'): - tokens = tokens.tolist() - else: - tokens = list(tokens) if tokens else [] - return { - 'rank': rank, - 'tokens': tokens, - 'num_tokens': len(tokens), - } - return {'rank': rank, 'tokens': [], 'num_tokens': 0} - - @remote_function(dispatch='all', collect='first') - def get_model_info(self) -> dict: - """Get model information.""" - import torch - rank = torch.cuda.current_device() if torch.cuda.is_available() else 0 - state_dict = self.model.get_state_dict() - if isinstance(state_dict, dict): - num_params = len(state_dict) - else: - num_params = sum(1 for _ in state_dict) - return { - 'rank': rank, - 'num_params': num_params, - 'model_id': self.model.model_id, - } - - -def test_hybrid_weight_sync(): - """Test weight sync in Hybrid mode with WORLD_SIZE GPUs (DP=WORLD_SIZE, TP=1). - - Each GPU runs one actor with: - - TransformersModel with real weights - - VLLMSampler with dummy weights initially - - IPCWeightLoader for weight sync - - Test verifies: - 1. Before sync: sampler produces garbage output (random weights) - 2. 
After sync: sampler produces correct output (real weights) +def get_model_path(): + """Resolve model_id to a local cache path (for offline environments).""" + try: + from modelscope.hub.snapshot_download import snapshot_download + _cache = snapshot_download(MODEL_ID, local_files_only=True) + if _cache: + return _cache + except Exception: + pass + return MODEL_ID + + +# ============================================================================= +# Test: Standalone Weight Sync +# ============================================================================= + +def test_standalone_weight_sync(model_gpus: int = 1, sampler_gpus: int = 1): + """Test weight sync in STANDALONE mode (model and sampler on different GPUs). + + Architecture: + Model workers : GPU 0 .. model_gpus-1 (training, real weights) + Sampler workers: GPU model_gpus .. total-1 (inference, dummy weights) + + Weight sync flow (managed by CheckpointEngineManager): + 1. prepare — allocate NCCL buffers, ZMQ metadata server + 2. build_topology — model[0]→rank0 (source), sampler→rank1..N + 3. init_process_group — temporary NCCL group + 4. send / receive — NCCL broadcast (parallel) + 5. finalize — release buffers, close ZMQ """ import twinkle - + from twinkle import DeviceGroup, DeviceMesh + from twinkle.model.transformers import TransformersModel + from twinkle.sampler import VLLMSampler + from twinkle.template import Template + from twinkle.checkpoint_engine import CheckpointEngineManager + from transformers import AutoTokenizer + from twinkle.data_format import Trajectory + from twinkle.sampler.types import SamplingParams + + total_gpus = model_gpus + sampler_gpus + model_path = get_model_path() + log("=" * 70) - log(f"TEST: Hybrid Weight Sync ({WORLD_SIZE} GPU, DP={WORLD_SIZE}, TP=1)") + log(f"TEST: Standalone Weight Sync") + log(f" Model : GPU 0-{model_gpus - 1} ({model_gpus} workers)") + log(f" Sampler: GPU {model_gpus}-{total_gpus - 1} ({sampler_gpus} workers)") + log(f" Model : {model_path}") log("=" * 70) - - # Initialize with WORLD_SIZE GPUs + + # ── Initialize Twinkle in Ray mode ──────────────────────────────── twinkle.initialize( mode='ray', - nproc_per_node=WORLD_SIZE, + nproc_per_node=total_gpus, groups=[ DeviceGroup( - name='hybrid', - ranks=[i for i in range(WORLD_SIZE)], - device_type='GPU', - gpus_per_worker=1, # Each worker gets 1 GPU (TP=1) + name='model', + ranks=list(range(model_gpus)), + device_type='GPU', + gpus_per_worker=1, + ), + DeviceGroup( + name='sampler', + ranks=list(range(model_gpus, total_gpus)), + device_type='GPU', + gpus_per_worker=1, ), ], ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - test_prompts = [ - "What is 2+2?", - "What is the capital of France?", - "What is 10*5?", - "Hello, my name is", - ] - - # Create hybrid actor (will spawn WORLD_SIZE instances, one per GPU) - log(f"Creating HybridModelSamplerActor on {WORLD_SIZE} GPUs...") - actor = HybridModelSamplerActor( - model_id=MODEL_ID, - device_mesh=DeviceMesh.from_sizes(world_size=WORLD_SIZE, dp_size=WORLD_SIZE), - remote_group='hybrid', + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # ── Create Model (real weights) ─────────────────────────────────── + log("\nCreating Model (real weights)...") + model = TransformersModel( + model_id=model_path, + device_mesh=DeviceMesh.from_sizes(world_size=model_gpus, dp_size=model_gpus), + remote_group='model', + ) + + # ── Create Sampler (dummy weights) ──────────────────────────────── + log("Creating Sampler (dummy weights)...") 
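+    # 'load_format': 'dummy' below makes vLLM materialize random weights without
+    # reading a checkpoint, so the pre-sync sample is guaranteed to be garbage.
+    # The sync performed later by CheckpointEngineManager is, at its core, an
+    # NCCL broadcast over a temporary group. A rough sketch with placeholder
+    # names (model_rank0, sampler_ranks, hf_state_items); the real flow is
+    # prepare -> build_topology -> init_process_group -> send/receive -> finalize:
+    #
+    #     group = dist.new_group(ranks=[model_rank0] + sampler_ranks)
+    #     for name, tensor in hf_state_items:
+    #         dist.broadcast(tensor, src=model_rank0, group=group)
+    #     dist.destroy_process_group(group)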
+ sampler = VLLMSampler( + model_id=model_path, + engine_args={ + 'load_format': 'dummy', # start with random weights + 'gpu_memory_utilization': 0.3, + 'max_model_len': 256, + 'enforce_eager': True, + 'enable_sleep_mode': True, + 'enable_lora': False, + }, + device_mesh=DeviceMesh.from_sizes(world_size=sampler_gpus, dp_size=sampler_gpus), + remote_group='sampler', + ) + sampler.set_template(Template, model_id=model_path) + + # Wait for vLLM initialization + log("Waiting for vLLM initialization...") + time.sleep(3) + + # ── Helper: sample one prompt ───────────────────────────────────── + def do_sample(prompt: str, max_tokens: int = 32) -> str: + traj = Trajectory(messages=[{'role': 'user', 'content': prompt}]) + response = wait_result( + sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0)) + ) + if response and response.sequences: + tokens = response.sequences[0].tokens + if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + return tokenizer.decode(tokens, skip_special_tokens=True) + return "" + + # ── Sample BEFORE sync (dummy weights → garbage) ────────────────── + log("\n--- Sampling BEFORE weight sync (dummy weights) ---") + text_before = do_sample("What is 2+2?") + log(f" Output: '{text_before[:100]}'") + + # ── Sync weights: Model → Sampler via NCCL ──────────────────────── + log("\n--- Syncing weights via CheckpointEngineManager ---") + manager = CheckpointEngineManager( + model=model, + sampler=sampler, ) - - # Wait for initialization and get model info - log("Waiting for actor initialization...") - model_info = wait_result(actor.get_model_info()) - log(f"Model info: {model_info}") - - # Test 1: Sample BEFORE weight sync (should produce garbage) - log("\n" + "-" * 50) - log("STEP 1: Sampling BEFORE weight sync (dummy weights)") - log("-" * 50) - - results_before = {} - for i, prompt in enumerate(test_prompts): - result = wait_result(actor.sample_text(prompt)) - text = tokenizer.decode(result['tokens'], skip_special_tokens=True) if result['tokens'] else "" - results_before[prompt] = { - 'tokens': result['tokens'], - 'text': text, - } - log(f" Prompt {i+1}: '{prompt}'") - log(f" Output: '{text[:80]}...' ({result['num_tokens']} tokens)") - - # Test 2: Sync weights - log("\n" + "-" * 50) - log("STEP 2: Syncing weights via IPCWeightLoader") - log("-" * 50) - + sync_start = time.time() - sync_result = wait_result(actor.sync_weights()) - sync_elapsed = time.time() - sync_start - log(f"Weight sync completed in {sync_elapsed:.2f}s") - - # Test 3: Sample AFTER weight sync (should produce correct output) - log("\n" + "-" * 50) - log("STEP 3: Sampling AFTER weight sync (real weights)") - log("-" * 50) - - results_after = {} - for i, prompt in enumerate(test_prompts): - result = wait_result(actor.sample_text(prompt)) - text = tokenizer.decode(result['tokens'], skip_special_tokens=True) if result['tokens'] else "" - results_after[prompt] = { - 'tokens': result['tokens'], - 'text': text, - } - log(f" Prompt {i+1}: '{prompt}'") - log(f" Output: '{text[:80]}...' 
({result['num_tokens']} tokens)") - - # Verify results - log("\n" + "-" * 50) - log("STEP 4: Verification") - log("-" * 50) - - all_passed = True - for prompt in test_prompts: - before = results_before[prompt] - after = results_after[prompt] - - # Check if outputs are different - outputs_differ = before['tokens'] != after['tokens'] - - # Check for expected answers - expected_answers = { - "What is 2+2?": ["4", "four"], - "What is the capital of France?": ["Paris", "paris"], - "What is 10*5?": ["50", "fifty"], - "Hello, my name is": [], # No specific expected answer - } - - expected = expected_answers.get(prompt, []) - has_correct_answer = any(ans.lower() in after['text'].lower() for ans in expected) if expected else True - - status = "PASS" if outputs_differ else "FAIL" - answer_status = "CORRECT" if has_correct_answer else "CHECK" - - log(f" '{prompt}':") - log(f" Before: '{before['text'][:50]}...'") - log(f" After: '{after['text'][:50]}...'") - log(f" Status: {status} (outputs differ: {outputs_differ}), Answer: {answer_status}") - - if not outputs_differ: - all_passed = False - - return all_passed + manager.sync_weights() + sampler.reset_prefix_cache() + sync_time = time.time() - sync_start + log(f" Weight sync completed in {sync_time:.2f}s") + # ── Sample AFTER sync (real weights → coherent) ─────────────────── + log("\n--- Sampling AFTER weight sync (real weights) ---") + text_after = do_sample("What is 2+2?") + log(f" Output: '{text_after[:100]}'") -def main(): - """Run weight sync test.""" - log("=" * 70) - log("TWINKLE WEIGHT SYNC TEST") - log(f"Model: {MODEL_ID}") - log(f"Configuration: Hybrid mode, {WORLD_SIZE} GPU, DP={WORLD_SIZE}, TP=1") + # ── Verification ────────────────────────────────────────────────── + log("\n" + "=" * 70) + log("VERIFICATION") log("=" * 70) - - results = [] - + + outputs_differ = text_before != text_after + log(f" Outputs differ after sync: {outputs_differ}") + + if outputs_differ: + log(" PASS: Weight sync verified — outputs changed after sync.") + if "4" in text_after.lower() or "four" in text_after.lower(): + log(" BONUS: Model correctly answered '2+2' question!") + else: + log(" FAIL: Outputs are identical — weight sync may have failed.") + sampler.shutdown() + + return outputs_differ + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser(description='Test STANDALONE weight synchronization') + parser.add_argument('--model-gpus', type=int, default=1, + help='Number of GPUs for model (training)') + parser.add_argument('--sampler-gpus', type=int, default=1, + help='Number of GPUs for sampler (inference)') + args = parser.parse_args() + + log(f"Starting standalone weight sync test...") + log(f" Model GPUs: {args.model_gpus}") + log(f" Sampler GPUs: {args.sampler_gpus}") + log(f" Model ID: {MODEL_ID}") + try: - passed = test_hybrid_weight_sync() - results.append((f'hybrid_weight_sync_{WORLD_SIZE}gpu', passed)) + success = test_standalone_weight_sync(args.model_gpus, args.sampler_gpus) except Exception as e: - log(f"Error in test: {e}") + log(f"\nTest failed with exception: {e}") import traceback traceback.print_exc() - results.append((f'hybrid_weight_sync_{WORLD_SIZE}gpu', False)) - - # Summary + success = False + log("\n" + "=" * 70) - log("FINAL SUMMARY") + log(f"RESULT: {'PASS' if success else 'FAIL'}") log("=" * 70) - for name, passed in results: - status = "PASSED" if passed else 
"FAILED" - log(f" {name}: {status}") - - passed_count = sum(1 for _, p in results if p) - log(f"\nTotal: {passed_count}/{len(results)} passed") - - if passed_count != len(results): - sys.exit(1) - - log("\nAll tests passed!") + + return 0 if success else 1 if __name__ == '__main__': - main() + sys.exit(main())