From ce5a4a20a608cd88a1632e18af0f06d004fb34d3 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 10 Mar 2026 17:23:34 +0800 Subject: [PATCH 01/56] wip --- cookbook/rl/gkd_off_policy.py | 163 +++++++++++++++++++ cookbook/rl/gkd_on_policy.py | 184 +++++++++++++++++++++ src/twinkle/loss/__init__.py | 3 + src/twinkle/loss/gkd.py | 232 +++++++++++++++++++++++++++ src/twinkle/preprocessor/__init__.py | 2 +- src/twinkle/preprocessor/llm.py | 26 ++- 6 files changed, 608 insertions(+), 2 deletions(-) create mode 100644 cookbook/rl/gkd_off_policy.py create mode 100644 cookbook/rl/gkd_on_policy.py create mode 100644 src/twinkle/loss/gkd.py diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py new file mode 100644 index 00000000..4704e2ea --- /dev/null +++ b/cookbook/rl/gkd_off_policy.py @@ -0,0 +1,163 @@ +"""GKD Off-Policy Distillation via Ray. + +Off-policy knowledge distillation: the student learns to match the teacher's +token distribution on pre-existing reference responses from the dataset. + +Pipeline: + 1. DataLoader supplies full-text batches (prompt + reference answer). + 2. Teacher TransformersModel runs forward_only() to get frozen logits. + 3. Student TransformersModel runs forward_backward() with GKDLoss. + +Key difference from on-policy: + - No vLLM sampler needed (responses already in the dataset). + - Simpler GPU layout: all GPUs can go to the model group. + - Faster per-step (no generation latency), but less exploration. + +Architecture (Ray): + ┌─────────────────────────────────────────────────────────────────┐ + │ Driver (CPU) │ + │ dataloader ──► full-text batch (prompt + reference answer) │ + │ teacher_model.forward_only() ──► frozen teacher logits │ + │ student_model.forward_backward(teacher_logits=...) 
──► GKD │ + └─────────────────────────────────────────────────────────────────┘ + │ + TransformersModel ×2 + student + teacher (all GPUs) + +Environment variables (all optional): + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct) + NUM_GPUS – total GPUs for both models (default: 4) + BATCH_SIZE – global batch size (default: 8) + MAX_STEPS – total optimisation steps (default: 200) + LR – learning rate (default: 1e-4) + GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) + GKD_TEMPERATURE – distillation temperature (default: 1.0) + GKD_TOPK – top-k vocab reduction; 0=full (default: 0) +""" + +import os + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import GKDLoss +from twinkle.model import TransformersModel +from twinkle.preprocessor import GSM8KFullProcessor + +logger = get_logger() + +# ── Configuration ───────────────────────────────────────────────────────────── +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') + +NUM_GPUS = int(os.environ.get('NUM_GPUS', 4)) + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) +LEARNING_RATE = float(os.environ.get('LR', 1e-4)) + +GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) +GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 0)) + +ADAPTER_NAME = 'default' + + +# ── Dataset ─────────────────────────────────────────────────────────────────── + +def create_dataset(): + """Full-text dataset with prompt + reference answer for off-policy distillation.""" + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', 
split='train')) + dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) + dataset.map(GSM8KFullProcessor()) + dataset.encode() + return dataset + + +# ── Training ────────────────────────────────────────────────────────────────── + +def main(): + device_groups = [ + DeviceGroup(name='model', ranks=list(range(NUM_GPUS)), device_type='cuda'), + ] + model_mesh = DeviceMesh.from_sizes(world_size=NUM_GPUS, dp_size=NUM_GPUS) + + twinkle.initialize( + mode='ray', + nproc_per_node=NUM_GPUS, + groups=device_groups, + lazy_collect=False, + ) + logger.info(get_device_placement()) + + # ── Student model (trainable) ────────────────────────────────────────────── + student_model = TransformersModel( + model_id=STUDENT_MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + student_model.add_adapter_to_model( + ADAPTER_NAME, + LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), + gradient_accumulation_steps=1, + ) + student_model.set_optimizer('AdamW', lr=LEARNING_RATE) + student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) + student_model.set_template('Template', model_id=STUDENT_MODEL_ID) + + # ── Teacher model (frozen, for logits) ───────────────────────────────────── + teacher_model = TransformersModel( + model_id=TEACHER_MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID) + + # ── DataLoader (full-text: prompt + reference answer) ────────────────────── + dataloader = DataLoader( + dataset=create_dataset, + batch_size=BATCH_SIZE, + min_batch_size=BATCH_SIZE, + device_mesh=model_mesh, + remote_group='model', + ) + + topk = GKD_TOPK if GKD_TOPK > 0 else None + + logger.info(f'GKD Off-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') + logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') + + optim_step = 0 
+ for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + input_data = batch if isinstance(batch, list) else [batch] + + # Teacher logits (frozen) + teacher_output = teacher_model.forward_only(inputs=input_data) + teacher_logits = teacher_output.get('logits') + + # Student forward + GKD backward + student_model.forward_backward(inputs=input_data, teacher_logits=teacher_logits, topk=topk) + student_model.clip_grad_and_step() + optim_step += 1 + + if optim_step % 10 == 0: + metric = student_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') + + if optim_step % 50 == 0: + student_model.save(f'gkd-offpolicy-ckpt-{optim_step}') + + student_model.save('gkd-offpolicy-final') + logger.info('GKD off-policy training completed.') + + +if __name__ == '__main__': + main() diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py new file mode 100644 index 00000000..09720a20 --- /dev/null +++ b/cookbook/rl/gkd_on_policy.py @@ -0,0 +1,184 @@ +"""GKD On-Policy Distillation via Ray. + +On-policy knowledge distillation: teacher vLLM generates fresh responses for +each prompt, then the student learns to match the teacher's token distribution. + +Pipeline: + 1. DataLoader supplies prompt-only batches. + 2. Teacher vLLM sampler generates completions on-the-fly. + 3. Teacher TransformersModel runs forward_only() to get frozen logits. + 4. Student TransformersModel runs forward_backward() with GKDLoss. + +Architecture (Ray): + ┌─────────────────────────────────────────────────────────────────┐ + │ Driver (CPU) │ + │ dataloader ──► prompt-only batch │ + │ teacher_sampler.sample() ──► on-policy completions │ + │ teacher_model.forward_only() ──► frozen teacher logits │ + │ student_model.forward_backward(teacher_logits=...) 
──► GKD │ + └─────────────────────────────────────────────────────────────────┘ + │ │ │ + DataLoader vLLMSampler TransformersModel ×2 + (model GPUs) (sampler GPUs) student + teacher (model GPUs) + +Environment variables (all optional): + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct) + MODEL_GPUS – GPUs for student + teacher models (default: 4) + SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 4) + MAX_NEW_TOKENS – max completion tokens (default: 512) + BATCH_SIZE – global prompt-level batch size (default: 8) + MAX_STEPS – total optimisation steps (default: 200) + LR – learning rate (default: 1e-4) + GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) + GKD_TEMPERATURE – distillation temperature (default: 1.0) + GKD_TOPK – top-k vocab reduction; 0=full (default: 0) +""" + +import os +from typing import List + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.data_format import SamplingParams +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import GKDLoss +from twinkle.model import TransformersModel +from twinkle.preprocessor import GSM8KProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Template + +logger = get_logger() + +# ── Configuration ───────────────────────────────────────────────────────────── +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 512)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) +LEARNING_RATE = 
float(os.environ.get('LR', 1e-4)) + +GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) +GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 0)) + +ADAPTER_NAME = 'default' + + +# ── Dataset ─────────────────────────────────────────────────────────────────── + +def create_dataset(): + """Prompt-only dataset; teacher vLLM will generate completions on-policy.""" + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) + dataset.map(GSM8KProcessor()) + dataset.encode(add_generation_prompt=True) + return dataset + + +# ── Training ────────────────────────────────────────────────────────────────── + +def main(): + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='cuda'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='cuda'), + ] + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + + twinkle.initialize( + mode='ray', + nproc_per_node=NUM_GPUS, + groups=device_groups, + lazy_collect=False, + ) + logger.info(get_device_placement()) + + # ── Student model (trainable) ────────────────────────────────────────────── + student_model = TransformersModel( + model_id=STUDENT_MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + student_model.add_adapter_to_model( + ADAPTER_NAME, + LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), + gradient_accumulation_steps=1, + ) + student_model.set_optimizer('AdamW', lr=LEARNING_RATE) + student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) + student_model.set_template('Template', model_id=STUDENT_MODEL_ID) + + # ── Teacher model (frozen, for 
logits) ───────────────────────────────────── + teacher_model = TransformersModel( + model_id=TEACHER_MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID) + + # ── Teacher vLLM sampler (for on-policy generation) ──────────────────────── + teacher_sampler = vLLMSampler( + model_id=TEACHER_MODEL_ID, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048}, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID) + + # ── DataLoader (prompt-only) ─────────────────────────────────────────────── + dataloader = DataLoader( + dataset=create_dataset, + batch_size=BATCH_SIZE, + min_batch_size=BATCH_SIZE, + device_mesh=model_mesh, + remote_group='model', + ) + + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0) + topk = GKD_TOPK if GKD_TOPK > 0 else None + + logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') + logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') + + optim_step = 0 + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + # Teacher vLLM generates completions + prompts: List = batch if isinstance(batch, list) else [batch] + sample_response = teacher_sampler.sample(prompts, sampling_params, num_samples=1) + input_data = [seq.new_input_feature for seq in sample_response.sequences] + + # Teacher logits (frozen) + teacher_output = teacher_model.forward_only(inputs=input_data) + teacher_logits = teacher_output.get('logits') + + # Student forward + GKD backward + student_model.forward_backward(inputs=input_data, teacher_logits=teacher_logits, topk=topk) + student_model.clip_grad_and_step() + optim_step += 1 + + if optim_step % 10 == 0: + metric = student_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') + + if optim_step % 50 == 0: + 
student_model.save(f'gkd-onpolicy-ckpt-{optim_step}') + + student_model.save('gkd-onpolicy-final') + logger.info('GKD on-policy training completed.') + + +if __name__ == '__main__': + main() diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index e03681ae..9ee57a97 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -2,6 +2,7 @@ from .base import Loss from .chunked_cross_entropy import ChunkedCrossEntropyLoss from .cross_entropy import CrossEntropyLoss +from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss @@ -11,6 +12,8 @@ 'cross_entropy': CrossEntropyLoss, 'chunked_cross_entropy': ChunkedCrossEntropyLoss, 'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss, + # KD losses + 'gkd': GKDLoss, # RL losses 'grpo': GRPOLoss, 'gspo': GSPOLoss, diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py new file mode 100644 index 00000000..e1a8b64e --- /dev/null +++ b/src/twinkle/loss/gkd.py @@ -0,0 +1,232 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from typing import TYPE_CHECKING, Optional + +from twinkle.data_format import LossOutput +from twinkle.loss.base import Loss + +if TYPE_CHECKING: + import torch + + +class GKDLoss(Loss): + """Generalized Knowledge Distillation (GKD) loss based on Jensen-Shannon Divergence. 
+
+    Implements the on-policy distillation objective from:
+    "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes"
+    (https://arxiv.org/abs/2306.13649)
+
+    The loss is a β-weighted mixture of two KL divergences:
+        JSD_β(S || T) = β · KL(T || M) + (1 - β) · KL(S || M)
+    where M = β · T + (1 - β) · S (mixture distribution)
+
+    Special cases:
+        β = 0   → forward KL(T || S) (mean-seeking)
+        β = 1   → reverse KL(S || T) (mode-seeking)
+        β = 0.5 → symmetric JSD
+
+    Args:
+        beta: Weight for teacher in the JSD mixture (default: 0.5).
+        temperature: Softmax temperature applied to logits before divergence (default: 1.0).
+        ignore_index: Token index to ignore in the loss mask (default: -100).
+        chunk_size: Number of valid tokens processed per chunk to reduce peak memory (default: 512).
+    """
+
+    def __init__(
+        self,
+        beta: float = 0.5,
+        temperature: float = 1.0,
+        ignore_index: int = -100,
+        chunk_size: int = 512,
+        **kwargs,
+    ):
+        self.beta = beta
+        self.temperature = temperature
+        self.ignore_index = ignore_index
+        self.chunk_size = chunk_size
+
+    def __call__(
+        self,
+        inputs,
+        outputs,
+        *,
+        teacher_logits: Optional['torch.Tensor'] = None,
+        teacher_topk_logprobs: Optional['torch.Tensor'] = None,
+        teacher_topk_indices: Optional['torch.Tensor'] = None,
+        topk: Optional[int] = None,
+        **kwargs,
+    ) -> LossOutput:
+        """Compute GKD / JSD distillation loss.
+
+        Args:
+            inputs: Dict containing 'labels' [batch, seq_len] with ignore_index for non-response tokens.
+            outputs: Dict containing 'logits' [batch, seq_len, vocab_size] from the student model.
+            teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from a local teacher.
+                Either teacher_logits or (teacher_topk_logprobs + teacher_topk_indices)
+                must be provided.
+            teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote teacher API.
+                Returned by a vLLM-compatible /v1/completions prompt_logprobs call.
+ teacher_topk_indices: [batch, seq_len, topk] token indices corresponding to teacher_topk_logprobs. + topk: If set together with teacher_logits, only the top-k teacher tokens are used to + reduce vocabulary size before computing the JSD (memory-efficient local teacher mode). + + Returns: + LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens. + """ + assert teacher_logits is not None or ( + teacher_topk_logprobs is not None and teacher_topk_indices is not None + ), ( + 'Either teacher_logits or both teacher_topk_logprobs and teacher_topk_indices must be provided.' + ) + + labels = inputs['labels'] + student_logits = outputs['logits'] + + # Align seq dimension: some MLLMs return extra prefix logits + if student_logits.shape[1] != labels.shape[1]: + student_logits = student_logits[:, -labels.shape[1]:] + + # Shift labels: label[i] = next token predicted by logits[i] + # The last position wraps to label[0] via roll; since label[0] is -100 (prompt), + # it will be correctly excluded by the mask in _generalized_jsd_loss. + shifted_labels = labels.roll(shifts=-1, dims=1) + + loss = self._generalized_jsd_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + labels=shifted_labels, + beta=self.beta, + temperature=self.temperature, + chunk_size=self.chunk_size, + topk=topk, + teacher_topk_logprobs=teacher_topk_logprobs, + teacher_topk_indices=teacher_topk_indices, + ) + return LossOutput(loss=loss, num_tokens=0) + + @staticmethod + def _generalized_jsd_loss( + student_logits, + teacher_logits=None, + labels=None, + beta: float = 0.5, + temperature: float = 1.0, + chunk_size: int = 512, + topk: Optional[int] = None, + teacher_topk_logprobs=None, + teacher_topk_indices=None, + ): + """Core chunked JSD loss computation. + + Supports three teacher modes: + 1. Full-vocabulary local teacher (teacher_logits, topk=None) + 2. Top-k local teacher (teacher_logits, topk=k) + 3. 
Remote API teacher (teacher_topk_logprobs + teacher_topk_indices) + + The function processes valid tokens in chunks to keep peak GPU memory bounded. + + Args: + student_logits: [batch, seq_len, vocab_size] or [batch, seq_len, topk] after top-k reduction. + teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from local teacher. + labels: [batch, seq_len] shifted labels; positions where value == ignore_index are excluded. + beta: JSD mixture weight (0=forward KL, 1=reverse KL, 0.5=symmetric JSD). + temperature: Softmax temperature. + chunk_size: Tokens per chunk. + topk: If given, reduce local teacher to top-k tokens before computing JSD. + teacher_topk_logprobs: [batch, seq_len, topk] from remote API. + teacher_topk_indices: [batch, seq_len, topk] from remote API. + + Returns: + Scalar loss tensor. + """ + import torch + import torch.nn.functional as F + + # ── Top-k reduction ────────────────────────────────────────────────── + if teacher_topk_logprobs is not None and teacher_topk_indices is not None: + # Remote API teacher: teacher already provides top-k log-probs (T=1). + # Divide both student and teacher by temperature, then re-normalise. 
+ s_scaled = student_logits / temperature + student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices) + teacher_logits = teacher_topk_logprobs / temperature + del s_scaled + temperature = 1.0 + elif topk is not None and teacher_logits is not None: + # Local teacher top-k reduction + t_scaled = teacher_logits / temperature + s_scaled = student_logits / temperature + teacher_logits, topk_idx = torch.topk(t_scaled, k=topk, dim=-1) + student_logits = torch.gather(s_scaled, dim=-1, index=topk_idx) + del t_scaled, s_scaled, topk_idx + temperature = 1.0 + + # ── Temperature scaling ─────────────────────────────────────────────── + student_logits = student_logits / temperature + teacher_logits = teacher_logits / temperature + + # ── Mask valid (response) tokens ────────────────────────────────────── + if labels is not None: + mask = labels != -100 # ignore_index is always -100 per convention + # Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller side + # so both distributions are defined over the same token set. 
+            stu_dim = student_logits.shape[-1]
+            tea_dim = teacher_logits.shape[-1]
+            if stu_dim < tea_dim:
+                student_logits = F.pad(student_logits, (0, tea_dim - stu_dim))
+                student_logits[..., stu_dim:] = teacher_logits[..., stu_dim:]
+            elif stu_dim > tea_dim:
+                teacher_logits = F.pad(teacher_logits, (0, stu_dim - tea_dim))
+                teacher_logits[..., tea_dim:] = student_logits[..., tea_dim:]
+            student_logits = student_logits[mask]  # [num_valid, vocab/topk]
+            teacher_logits = teacher_logits[mask]
+            num_valid = mask.sum()
+        else:
+            student_logits = student_logits.view(-1, student_logits.size(-1))
+            teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1))
+            num_valid = student_logits.size(0)
+
+        if num_valid == 0:
+            return student_logits.new_zeros(())
+
+        num_valid_int = int(num_valid) if isinstance(num_valid, int) else num_valid.item()
+        total_loss = student_logits.new_zeros(())
+
+        # Pre-compute log(beta) / log(1-beta) once for the mixture
+        if beta not in (0, 1):
+            beta_t = torch.tensor(beta, dtype=student_logits.dtype, device=student_logits.device)
+            log_beta = torch.log(beta_t)
+            log_1_minus_beta = torch.log1p(-beta_t)
+        else:
+            beta_t = log_beta = log_1_minus_beta = None
+
+        # ── Chunked loss accumulation ─────────────────────────────────────────
+        for start in range(0, num_valid_int, chunk_size):
+            end = min(start + chunk_size, num_valid_int)
+            s_chunk = student_logits[start:end]
+            t_chunk = teacher_logits[start:end]
+
+            s_log_probs = F.log_softmax(s_chunk, dim=-1)
+            t_log_probs = F.log_softmax(t_chunk, dim=-1)
+            del s_chunk, t_chunk
+
+            if beta == 0:
+                # Forward KL: KL(T || S) — F.kl_div(input, target, log_target=True) is KL(target || input)
+                jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
+            elif beta == 1:
+                # Reverse KL: KL(S || T)
+                jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)
+            else:
+                # Generalised JSD: β·KL(T||M) + (1-β)·KL(S||M)
+                mixture_log_probs = torch.logsumexp(
+                    torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]),
+                    dim=0,
+                )
+ kl_teacher = F.kl_div(mixture_log_probs, t_log_probs, reduction='none', log_target=True) + kl_student = F.kl_div(mixture_log_probs, s_log_probs, reduction='none', log_target=True) + del mixture_log_probs + jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student + del kl_teacher, kl_student + + total_loss = total_loss + jsd_chunk.sum() + del jsd_chunk, s_log_probs, t_log_probs + + return total_loss / num_valid diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 1c19815e..7234a60a 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, - SelfCognitionProcessor) + GSM8KFullProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index ddafb351..565f1661 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -129,7 +129,7 @@ class GSM8KProcessor(Preprocessor): def extract_ground_truth(self, answer_str: str) -> str: """Extract the number after '####' from GSM8K answer.""" - match = re.search(r'####\s*([\-\d,\.]+)', answer_str) + match = re.search(r'####\s*([\-\d,.]+)', answer_str) if match: return match.group(1).replace(',', '').strip() return '' @@ -153,3 +153,27 @@ def preprocess(self, row) -> Trajectory: messages=messages, user_data=[('ground_truth', ground_truth)], ) + + +class GSM8KFullProcessor(GSM8KProcessor): + """GSM8K preprocessor that includes the reference answer as the assistant message. + + Produces a full Trajectory (prompt + reference answer) suitable for + off-policy knowledge distillation: the student and teacher both see the + ground-truth response text, and labels cover the response tokens. 
+ """ + + def preprocess(self, row) -> Trajectory: + question = row['question'] + answer = row.get('answer', '') + ground_truth = self.extract_ground_truth(answer) + + messages = [ + Message(role='system', content=self.system_prompt), + Message(role='user', content=question), + Message(role='assistant', content=answer), + ] + return Trajectory( + messages=messages, + user_data=[('ground_truth', ground_truth)], + ) From 89b96b4065190b152b82c1edcee5bd26dbdf2887 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 10 Mar 2026 22:19:40 +0800 Subject: [PATCH 02/56] wip --- cookbook/rl/gkd_on_policy.py | 10 ++--- src/twinkle/data_format/trajectory.py | 1 - src/twinkle/loss/__init__.py | 5 +-- src/twinkle/loss/cross_entropy.py | 42 ++++++++++++++----- src/twinkle/loss/gkd.py | 42 +++++++++---------- .../loss/vocab_parallel_cross_entropy.py | 20 --------- src/twinkle/model/megatron/megatron.py | 6 ++- .../model/transformers/strategy/accelerate.py | 2 +- .../model/transformers/transformers.py | 13 ++++-- 9 files changed, 73 insertions(+), 68 deletions(-) delete mode 100644 src/twinkle/loss/vocab_parallel_cross_entropy.py diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 09720a20..9dbe8a6a 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -55,7 +55,7 @@ # ── Configuration ───────────────────────────────────────────────────────────── STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-4B') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) @@ -98,7 +98,6 @@ def main(): mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, - lazy_collect=False, ) logger.info(get_device_placement()) @@ -156,16 +155,15 @@ def main(): break # Teacher vLLM generates completions - 
prompts: List = batch if isinstance(batch, list) else [batch] - sample_response = teacher_sampler.sample(prompts, sampling_params, num_samples=1) + sample_response = teacher_sampler.sample(batch, sampling_params, num_samples=1) input_data = [seq.new_input_feature for seq in sample_response.sequences] # Teacher logits (frozen) teacher_output = teacher_model.forward_only(inputs=input_data) - teacher_logits = teacher_output.get('logits') + teacher_output = teacher_output() # Student forward + GKD backward - student_model.forward_backward(inputs=input_data, teacher_logits=teacher_logits, topk=topk) + student_model.forward_backward(inputs=input_data, teacher_output=teacher_output, topk=topk) student_model.clip_grad_and_step() optim_step += 1 diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py index a4f694cb..c7742d75 100644 --- a/src/twinkle/data_format/trajectory.py +++ b/src/twinkle/data_format/trajectory.py @@ -15,5 +15,4 @@ class Trajectory(TypedDict, total=False): messages: List[Message] extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] - advantages: float user_data: List[Tuple[str, Any]] diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 9ee57a97..65303dfd 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -5,13 +5,12 @@ from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss -from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss +from .cross_entropy import CrossEntropyLoss torch_loss_mapping = { 'mse': MSELoss, - 'cross_entropy': CrossEntropyLoss, 'chunked_cross_entropy': ChunkedCrossEntropyLoss, - 'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss, + 'cross_entropy': CrossEntropyLoss, # KD losses 'gkd': GKDLoss, # RL losses diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py index 12851d45..06bf791a 100644 --- 
a/src/twinkle/loss/cross_entropy.py +++ b/src/twinkle/loss/cross_entropy.py @@ -1,20 +1,40 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from twinkle.data_format import LossOutput -from twinkle.utils import selective_log_softmax +from ..data_format import LossOutput from .base import Loss class CrossEntropyLoss(Loss): + """Calculate CE from logps""" - def __init__(self, **kwargs): - self.reduction = kwargs.get('reduction', 'mean') + def __init__(self, ignore_index: int = -100, reduction='mean', **kwargs): + super().__init__() + self.ignore_index = ignore_index + self.reduction = reduction def __call__(self, inputs, outputs, **kwargs): - import torch - logits = outputs['logits'].view(-1, outputs['logits'].shape[-1]) - labels = inputs['labels'].view(-1) - loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) - if self.reduction != 'sum': - return LossOutput(loss=loss, num_tokens=0) + labels = inputs['labels'] + logps = outputs.get('logps') + logits = outputs.get('logits') + + if logps is not None: + loss_mask = (labels != self.ignore_index).float() + if self.reduction != 'sum': + return LossOutput( + loss=(-logps * loss_mask).sum() / loss_mask.sum().clamp(min=1), + num_tokens=0, + ) + else: + return LossOutput( + loss=(-logps * loss_mask).sum(), + num_tokens=loss_mask.sum().clamp(min=1), + ) else: - return LossOutput(loss=loss, num_tokens=(labels != -100).sum()) + import torch + assert logits is not None + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) + if self.reduction != 'sum': + return LossOutput(loss=loss, num_tokens=0) + else: + return LossOutput(loss=loss, num_tokens=(labels != self.ignore_index).sum()) diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index e1a8b64e..027b1b74 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -49,9 +49,7 @@ def __call__( inputs, outputs, *, - 
teacher_logits: Optional['torch.Tensor'] = None,
-        teacher_topk_logprobs: Optional['torch.Tensor'] = None,
-        teacher_topk_indices: Optional['torch.Tensor'] = None,
+        teacher_output: Optional[dict] = None,
         topk: Optional[int] = None,
         **kwargs,
     ) -> LossOutput:
@@ -60,18 +58,23 @@ def __call__(
         Args:
             inputs: Dict containing 'labels' [batch, seq_len] with ignore_index for non-response tokens.
             outputs: Dict containing 'logits' [batch, seq_len, vocab_size] from the student model.
-            teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from a local teacher.
+            teacher_output: A dict containing:
+                teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from a local teacher.
                 Either teacher_logits or (teacher_topk_logprobs + teacher_topk_indices)
                 must be provided.
-            teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote teacher API.
-                Returned by a vLLM-compatible /v1/completions prompt_logprobs call.
-            teacher_topk_indices: [batch, seq_len, topk] token indices corresponding to teacher_topk_logprobs.
+                teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote teacher API.
+                    Returned by a vLLM-compatible /v1/completions prompt_logprobs call.
+                teacher_topk_indices: [batch, seq_len, topk] token indices corresponding to teacher_topk_logprobs.
             topk: If set together with teacher_logits, only the top-k teacher tokens are used to
                 reduce vocabulary size before computing the JSD (memory-efficient local teacher mode).
 
         Returns:
             LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens.
""" + breakpoint() + teacher_logits = teacher_output.get('logits') + teacher_topk_logprobs = teacher_output.get('teacher_topk_logprobs') + teacher_topk_indices = teacher_output.get('teacher_topk_indices') assert teacher_logits is not None or ( teacher_topk_logprobs is not None and teacher_topk_indices is not None ), ( @@ -142,27 +145,22 @@ def _generalized_jsd_loss( import torch.nn.functional as F # ── Top-k reduction ────────────────────────────────────────────────── + # Top-k mode: gather/topk first to get small [*, k] tensors, then scale in-place if teacher_topk_logprobs is not None and teacher_topk_indices is not None: # Remote API teacher: teacher already provides top-k log-probs (T=1). - # Divide both student and teacher by temperature, then re-normalise. - s_scaled = student_logits / temperature - student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices) + # Gather student logits at teacher's top-k indices, then scale in-place. + student_logits = torch.gather(student_logits, dim=-1, index=teacher_topk_indices) + student_logits.div_(temperature) teacher_logits = teacher_topk_logprobs / temperature - del s_scaled temperature = 1.0 elif topk is not None and teacher_logits is not None: - # Local teacher top-k reduction - t_scaled = teacher_logits / temperature - s_scaled = student_logits / temperature - teacher_logits, topk_idx = torch.topk(t_scaled, k=topk, dim=-1) - student_logits = torch.gather(s_scaled, dim=-1, index=topk_idx) - del t_scaled, s_scaled, topk_idx + # Local teacher: select top-k from teacher, gather corresponding student logits + teacher_logits, topk_idx = torch.topk(teacher_logits, k=topk, dim=-1) + teacher_logits.div_(temperature) + student_logits = torch.gather(student_logits, dim=-1, index=topk_idx) + student_logits.div_(temperature) temperature = 1.0 - # ── Temperature scaling ─────────────────────────────────────────────── - student_logits = student_logits / temperature - teacher_logits = teacher_logits / temperature - # 
── Mask valid (response) tokens ────────────────────────────────────── if labels is not None: mask = labels != -100 # ignore_index is always -100 per convention @@ -183,6 +181,8 @@ def _generalized_jsd_loss( student_logits = student_logits.view(-1, student_logits.size(-1)) teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) num_valid = student_logits.size(0) + student_logits.div_(temperature) + teacher_logits.div_(temperature) if num_valid == 0: return student_logits.new_zeros(()) diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py deleted file mode 100644 index 166e843f..00000000 --- a/src/twinkle/loss/vocab_parallel_cross_entropy.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from ..data_format import LossOutput -from .base import Loss - - -class VocabParallelCrossEntropyLoss(Loss): - - def __init__(self, ignore_index: int = -100): - super().__init__() - self.ignore_index = ignore_index - - def __call__(self, inputs, outputs, **kwargs): - labels = inputs['labels'] - logps = outputs.get('logps') - - loss_mask = (labels != self.ignore_index).float() - return LossOutput( - loss=(-logps * loss_mask).sum(), - num_tokens=loss_mask.sum().clamp(min=1), - ) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 68e68f5a..29c67d73 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -26,7 +26,7 @@ from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.data_format import InputFeature, ModelOutput, Trajectory from twinkle.hub import HubOperation -from twinkle.loss import Loss, VocabParallelCrossEntropyLoss +from twinkle.loss import Loss, CrossEntropyLoss from twinkle.metric import LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel from twinkle.patch import Patch, apply_patch @@ -238,7 +238,7 @@ def __init__( def 
_construct_default_optimizer_group(self): return MegatronOptimizerGroup( - loss_instance=VocabParallelCrossEntropyLoss(), + loss_instance=CrossEntropyLoss(), template=Template(self.tokenizer_id), processor=InputProcessor(self.device_mesh, framework='megatron'), _device_mesh=self.device_mesh, @@ -398,6 +398,7 @@ def forward_backward(self, from megatron.core.pipeline_parallel import get_forward_backward_func adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + temperature = float(kwargs.pop('temperature', 1.0)) forward_only = kwargs.pop('forward_only', False) optimizer_config = self.optimizer_group[adapter_name] loss_instance = self.optimizer_group[adapter_name].loss_instance @@ -485,6 +486,7 @@ def forward_step_func(data_iterator, model): loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 + output_tensor.div_(temperature) logps = selective_log_softmax(output_tensor, masked_labels) if cp_size > 1: logps = self._postprocess_tensor_cp(logps) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0e76378..c3388bcc 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -80,7 +80,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di fsdp_size = device_mesh.get_dim_size('fsdp') if device_mesh.has_dim('fsdp') else 1 dp_size = device_mesh.get_dim_size('dp') if device_mesh.has_dim('dp') else 1 - if fsdp_size == 1 and dp_size == 1: + if fsdp_size == 1: return None fsdp_config = fsdp_config or {} diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 0d92bd24..1bc2187b 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -355,6 +355,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], 
List[Trajec The output of the model forward. """ adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + temperature = float(kwargs.pop('temperature', 1.0)) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -386,10 +387,12 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 - outputs['logps'] = selective_log_softmax(outputs['logits'], masked_labels) + logits = outputs['logits'] + logits.div_(temperature) + outputs['logps'] = selective_log_softmax(logits, masked_labels) return outputs - @remote_function(dispatch='slice_dp', collect='flatten') + @remote_function(dispatch='slice_dp') def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Call forward function without grad and record the inputs and outputs. @@ -401,6 +404,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T The output of the model forward. 
""" adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + temperature = float(kwargs.pop('temperature', 1.0)) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -433,7 +437,10 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 - outputs['logps'] = selective_log_softmax(outputs['logits'], masked_labels) + logits = outputs['logits'] + logits.div_(temperature) + outputs['logps'] = selective_log_softmax(logits, masked_labels) + outputs.pop('past_key_values', None) return outputs @remote_function(collect='mean') From 85c5afb7bd09ee0d6efa88ef8c1fe5776d9009a6 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 10 Mar 2026 23:32:02 +0800 Subject: [PATCH 03/56] wip --- src/twinkle/infra/__init__.py | 1 + src/twinkle/infra/collectors.py | 75 +++++++++++++++++++ src/twinkle/model/megatron/megatron.py | 2 +- .../model/transformers/transformers.py | 7 +- src/twinkle/utils/__init__.py | 2 +- src/twinkle/utils/torch_utils.py | 33 +++++++- 6 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 src/twinkle/infra/collectors.py diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 9c37b367..9afd5d52 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -6,6 +6,7 @@ from typing import Any, Callable, List, Literal, Optional, TypeVar, Union from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires +from .collectors import collect_tensor_dict logger = get_logger() diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py new file mode 100644 index 00000000..a29583e9 --- /dev/null +++ b/src/twinkle/infra/collectors.py @@ -0,0 +1,75 @@ +from typing import List, Dict, Any +import torch + + +def collect_tensor_dict(outputs: List[Dict[str, Any]]) -> 
Dict[str, Any]: + if not outputs: + return {} + + if len(outputs) == 1: + return outputs[0] + + all_keys = set() + for d in outputs: + all_keys.update(d.keys()) + + result = {} + for key in all_keys: + values = [d[key] for d in outputs if key in d] + + if not values: + continue + + first_value = values[0] + + if isinstance(first_value, list): + merged = [] + for v in values: + if isinstance(v, list): + merged.extend(v) + else: + merged.append(v) + result[key] = merged + + elif isinstance(first_value, torch.Tensor): + result[key] = _pad_and_stack_tensors(values) + + elif isinstance(first_value, dict): + result[key] = collect_tensor_dict(values) + + else: + result[key] = values + + return result + + +def _pad_and_stack_tensors(tensors: List[torch.Tensor], pad_value: float = 0) -> torch.Tensor: + if not tensors: + raise ValueError("Empty tensor list") + + if len(tensors) == 1: + return tensors[0].unsqueeze(0) + + max_ndim = max(t.ndim for t in tensors) + expanded_tensors = [] + for t in tensors: + while t.ndim < max_ndim: + t = t.unsqueeze(0) + expanded_tensors.append(t) + + max_shape = [] + for dim in range(max_ndim): + max_shape.append(max(t.shape[dim] for t in expanded_tensors)) + + padded_tensors = [] + for t in expanded_tensors: + if list(t.shape) == max_shape: + padded_tensors.append(t) + else: + pad_params = [] + for dim in range(max_ndim - 1, -1, -1): + pad_params.extend([0, max_shape[dim] - t.shape[dim]]) + padded = torch.nn.functional.pad(t, pad_params, value=pad_value) + padded_tensors.append(padded) + + return torch.stack(padded_tensors, dim=0) \ No newline at end of file diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 29c67d73..c796a6e4 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -364,7 +364,7 @@ def calculate_loss(self, **kwargs): def backward(self, **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') 
- @remote_function(dispatch='slice_dp', collect='mean', sync=True) + @remote_function(dispatch='slice_dp', collect='last_pp', sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 1bc2187b..4f39a176 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -26,6 +26,7 @@ from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.data_format import InputFeature, ModelOutput, Trajectory from twinkle.hub import HubOperation +from twinkle.infra import collect_tensor_dict from twinkle.loss import CrossEntropyLoss, Loss from twinkle.metric import Accuracy, LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel @@ -343,7 +344,7 @@ def _construct_default_optimizer_group(self): _device_mesh=self.device_mesh, ) - @remote_function() + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Call forward function and record the inputs and outputs. @@ -392,7 +393,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec outputs['logps'] = selective_log_softmax(logits, masked_labels) return outputs - @remote_function(dispatch='slice_dp') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Call forward function without grad and record the inputs and outputs. 
@@ -502,7 +503,7 @@ def backward(self, **kwargs): optimizer_config.cur_step += 1 optimizer_config.loss_value = None - @remote_function(dispatch='slice_dp', collect='mean') + @remote_function(dispatch='slice_dp', collect='flatten') def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Do forward, calculate loss, and backward. diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 1b018773..0218bf7d 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -10,7 +10,7 @@ from .parallel import processing_lock from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device +from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device, split_tensor_to_list from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 35c636d8..1ec1786c 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -1,6 +1,6 @@ import socket from datetime import timedelta -from typing import TYPE_CHECKING, Any, Mapping, Union +from typing import TYPE_CHECKING, Any, Mapping, Union, List, Dict from .network import is_valid_ipv6_address @@ -192,3 +192,34 @@ def stateless_init_process_group( communicator = Communicator(pg, device=device) return communicator + + +def split_tensor_to_list(data: Union[torch.Tensor, Dict[str, Any]]) -> Union[List[torch.Tensor], List[Dict[str, Any]]]: + if isinstance(data, torch.Tensor): + return list(data) + + if isinstance(data, dict): + 
if not data: + return [{}] + + batch_size = None + for v in data.values(): + if isinstance(v, torch.Tensor): + batch_size = v.size(0) + break + + if batch_size is None: + return [data] + + result = [] + for i in range(batch_size): + item = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + item[k] = v[i] + else: + item[k] = v + result.append(item) + return result + + raise TypeError(f"Unsupported type: {type(data)}, expected Tensor or Dict") From 75d0377e67b685310140c76fda506cb1f4e2aa99 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 10 Mar 2026 23:52:29 +0800 Subject: [PATCH 04/56] fix --- src/twinkle/infra/collectors.py | 4 +-- .../model/transformers/transformers.py | 3 +- src/twinkle/utils/__init__.py | 2 +- src/twinkle/utils/torch_utils.py | 31 ------------------- 4 files changed, 5 insertions(+), 35 deletions(-) diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index a29583e9..456258fd 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -2,7 +2,7 @@ import torch -def collect_tensor_dict(outputs: List[Dict[str, Any]]) -> Dict[str, Any]: +def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh) -> Dict[str, Any]: if not outputs: return {} @@ -72,4 +72,4 @@ def _pad_and_stack_tensors(tensors: List[torch.Tensor], pad_value: float = 0) -> padded = torch.nn.functional.pad(t, pad_params, value=pad_value) padded_tensors.append(padded) - return torch.stack(padded_tensors, dim=0) \ No newline at end of file + return torch.cat(padded_tensors, dim=0) \ No newline at end of file diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 4f39a176..3dccfd13 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -391,6 +391,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec logits = outputs['logits'] logits.div_(temperature) 
outputs['logps'] = selective_log_softmax(logits, masked_labels) + outputs['past_key_values'] = None return outputs @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) @@ -441,7 +442,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T logits = outputs['logits'] logits.div_(temperature) outputs['logps'] = selective_log_softmax(logits, masked_labels) - outputs.pop('past_key_values', None) + outputs['past_key_values'] = None return outputs @remote_function(collect='mean') diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 0218bf7d..1b018773 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -10,7 +10,7 @@ from .parallel import processing_lock from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device, split_tensor_to_list +from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 1ec1786c..7097f946 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -192,34 +192,3 @@ def stateless_init_process_group( communicator = Communicator(pg, device=device) return communicator - - -def split_tensor_to_list(data: Union[torch.Tensor, Dict[str, Any]]) -> Union[List[torch.Tensor], List[Dict[str, Any]]]: - if isinstance(data, torch.Tensor): - return list(data) - - if isinstance(data, dict): - if not data: - return [{}] - - batch_size = None - for v in data.values(): - if isinstance(v, 
torch.Tensor): - batch_size = v.size(0) - break - - if batch_size is None: - return [data] - - result = [] - for i in range(batch_size): - item = {} - for k, v in data.items(): - if isinstance(v, torch.Tensor): - item[k] = v[i] - else: - item[k] = v - result.append(item) - return result - - raise TypeError(f"Unsupported type: {type(data)}, expected Tensor or Dict") From 79c22fb1ee90096003f7e1850c65462d6c792931 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 10 Mar 2026 23:57:57 +0800 Subject: [PATCH 05/56] fix --- src/twinkle/infra/__init__.py | 3 ++- src/twinkle/infra/collectors.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 9afd5d52..4bd5b3a8 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -344,7 +344,8 @@ def dispatch_func(arg, n): length = len(workers) def dispatch_func(arg, n): - if isinstance(arg, list): + import torch + if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] for i in range(n): _args.append(arg[device_mesh.get_slice(len(arg), device_mesh.get_data_rank_from_global_rank(i))]) diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 456258fd..749d664c 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -2,7 +2,7 @@ import torch -def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh) -> Dict[str, Any]: +def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: if not outputs: return {} From 1082035e47314d1ca2dddc89832d8497234274ea Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 11 Mar 2026 17:48:50 +0800 Subject: [PATCH 06/56] wip --- src/twinkle/infra/__init__.py | 7 +++++++ src/twinkle/infra/collectors.py | 2 +- src/twinkle/loss/gkd.py | 3 ++- src/twinkle/model/transformers/transformers.py | 13 ++++++++----- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git 
a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 4bd5b3a8..f25f0fa3 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -350,6 +350,13 @@ def dispatch_func(arg, n): for i in range(n): _args.append(arg[device_mesh.get_slice(len(arg), device_mesh.get_data_rank_from_global_rank(i))]) return _args + elif isinstance(arg, dict): + _args = [{} for _ in range(n)] + for key in arg.keys(): + value = arg[key] + for i, v in enumerate(dispatch_func(value, n)): + _args[i][key] = v + return _args else: return [arg] * n diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 749d664c..643a9c66 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -17,7 +17,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An for key in all_keys: values = [d[key] for d in outputs if key in d] - if not values: + if not values or all([v is None for v in values]): continue first_value = values[0] diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 027b1b74..25878739 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -71,7 +71,6 @@ def __call__( Returns: LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens. 
""" - breakpoint() teacher_logits = teacher_output.get('logits') teacher_topk_logprobs = teacher_output.get('teacher_topk_logprobs') teacher_topk_indices = teacher_output.get('teacher_topk_indices') @@ -83,6 +82,8 @@ def __call__( labels = inputs['labels'] student_logits = outputs['logits'] + if teacher_logits.shape[1] > student_logits.shape[1]: + teacher_logits = teacher_logits[:, :student_logits.shape[1]] # Align seq dimension: some MLLMs return extra prefix logits if student_logits.shape[1] != labels.shape[1]: diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 3dccfd13..4619bcec 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -105,11 +105,14 @@ def _ensure_dp_group(self): self._dp_group = self._device_mesh.create_process_group(dims) def _get_lr(self): - _lrs = [] - _default_lr = self.optimizer.defaults.get('lr') - for param_group in self.optimizer.param_groups: - _lrs.append(param_group.get('lr', _default_lr)) - return _lrs + if self.optimizer is not None: + _lrs = [] + _default_lr = self.optimizer.defaults.get('lr') + for param_group in self.optimizer.param_groups: + _lrs.append(param_group.get('lr', _default_lr)) + return _lrs + else: + return [] def accumulate_metrics(self, is_training): self._ensure_dp_group() From eb6b5be407777ac41278a1d09e6799362601921a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 11 Mar 2026 17:50:10 +0800 Subject: [PATCH 07/56] fix --- src/twinkle/infra/collectors.py | 10 +++++++--- src/twinkle/loss/gkd.py | 7 +++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 749d664c..8e5794c2 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -1,5 +1,7 @@ -from typing import List, Dict, Any -import torch +from typing import List, Dict, Any, TYPE_CHECKING + +if TYPE_CHECKING: + import 
torch def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: @@ -13,6 +15,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An for d in outputs: all_keys.update(d.keys()) + import torch result = {} for key in all_keys: values = [d[key] for d in outputs if key in d] @@ -43,7 +46,8 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An return result -def _pad_and_stack_tensors(tensors: List[torch.Tensor], pad_value: float = 0) -> torch.Tensor: +def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = 0) -> 'torch.Tensor': + import torch if not tensors: raise ValueError("Empty tensor list") diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 027b1b74..9d1d1ab5 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -71,14 +71,13 @@ def __call__( Returns: LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens. """ - breakpoint() teacher_logits = teacher_output.get('logits') - teacher_topk_logprobs = teacher_output.get('teacher_topk_logprobs') - teacher_topk_indices = teacher_output.get('teacher_topk_indices') + teacher_topk_logprobs = teacher_output.get('topk_logprobs') + teacher_topk_indices = teacher_output.get('topk_indices') assert teacher_logits is not None or ( teacher_topk_logprobs is not None and teacher_topk_indices is not None ), ( - 'Either teacher_logits or both teacher_topk_logprobs and teacher_topk_indices must be provided.' + 'Either logits or both topk_logprobs and topk_indices must be provided.' 
) labels = inputs['labels'] From 09c3c0ffff4e75b589bad0ad2b5ac87c7bd1fb20 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 11 Mar 2026 18:26:21 +0800 Subject: [PATCH 08/56] wip --- cookbook/transformers/fsdp2.py | 13 +++++++------ .../Usage Guide/Introduction-with-Qwen3.5.md | 4 ++-- ...\234\200\344\275\263\345\256\236\350\267\265.md" | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ca37d724..e0f67537 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -20,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +35,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -43,13 +43,14 @@ def train(): # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3-4B') + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` # Comment this to use full-parameter training - model.add_adapter_to_model('default', 
lora_config, gradient_accumulation_steps=2) + # model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) # Add LRScheduler for lora `default` @@ -60,8 +61,8 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # lora: 18G * 4 - # full: 50G * 4 + # lora: 8G * 8 + # full: 18G * 8 for step, batch in enumerate(dataloader): # Do forward and backward model.forward_backward(inputs=batch) diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md index 2f67e37b..b46c9c20 100644 --- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md +++ b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md @@ -89,8 +89,8 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # LoRA training: ~18G * 4 GPU memory - # Full-parameter training: ~50G * 4 GPU memory + # LoRA training: ~8G * 8 GPU memory + # Full-parameter training: ~18G * 8 GPU memory for step, batch in enumerate(dataloader): # Forward + backward pass model.forward_backward(inputs=batch) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index 8b86b9b0..b4ca94cd 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -89,8 +89,8 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # LoRA 训练:约 18G * 4 显存占用 - # 全参数训练:约 
50G * 4 显存占用 + # LoRA 训练:约 8G * 8 显存占用 + # 全参数训练:约 18G * 8 显存占用 for step, batch in enumerate(dataloader): # 前向 + 反向传播 model.forward_backward(inputs=batch) From e7677f5da90adae1b33ad798726479ebac9b30f6 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 13 Mar 2026 17:05:07 +0800 Subject: [PATCH 09/56] fix --- cookbook/rl/gkd_on_policy.py | 4 +--- src/twinkle/loss/gkd.py | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 9dbe8a6a..e652c35b 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -128,7 +128,7 @@ def main(): # ── Teacher vLLM sampler (for on-policy generation) ──────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs'}, device_mesh=sampler_mesh, remote_group='sampler', ) @@ -160,8 +160,6 @@ def main(): # Teacher logits (frozen) teacher_output = teacher_model.forward_only(inputs=input_data) - teacher_output = teacher_output() - # Student forward + GKD backward student_model.forward_backward(inputs=input_data, teacher_output=teacher_output, topk=topk) student_model.clip_grad_and_step() diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index e46b2a46..659f552a 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -82,12 +82,11 @@ def __call__( labels = inputs['labels'] student_logits = outputs['logits'] - if teacher_logits.shape[1] > student_logits.shape[1]: - teacher_logits = teacher_logits[:, :student_logits.shape[1]] - # Align seq dimension: some MLLMs return extra prefix logits if student_logits.shape[1] != labels.shape[1]: student_logits = student_logits[:, -labels.shape[1]:] + if teacher_logits.shape[1] > student_logits.shape[1]: + teacher_logits = teacher_logits[:, :student_logits.shape[1]] # Shift labels: 
label[i] = next token predicted by logits[i] # The last position wraps to label[0] via roll; since label[0] is -100 (prompt), From a4ff6c549c77cb5e31056cf4951f3c875370a7e4 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 13 Mar 2026 17:35:03 +0800 Subject: [PATCH 10/56] fix --- src/twinkle/data_format/sampling.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 129aea8e..5b46e152 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -17,21 +17,20 @@ class SamplingParams: top_k: int = -1 top_p: float = 1.0 repetition_penalty: float = 1.0 + logprobs: int = None + prompt_logprobs: int = None + num_samples: int = 1 - def to_vllm(self, *, num_samples: int = 1, logprobs: bool = True, prompt_logprobs: int = 0): + def to_vllm(self, **kwargs): """Convert to vLLM SamplingParams. - - Args: - num_samples: Number of completions per prompt (vLLM's 'n' parameter). - logprobs: Whether to return logprobs for generated tokens. - prompt_logprobs: Number of prompt token logprobs to return. 
""" from vllm import SamplingParams as VLLMSamplingParams kwargs = { 'temperature': self.temperature, 'top_p': self.top_p, - 'n': num_samples, + 'n': self.num_samples, + **kwargs, } if self.max_tokens is not None: @@ -54,14 +53,14 @@ def to_vllm(self, *, num_samples: int = 1, logprobs: bool = True, prompt_logprob else: kwargs['stop'] = list(self.stop) - if logprobs: - kwargs['logprobs'] = 0 + if self.logprobs is not None: + kwargs['logprobs'] = self.logprobs - if prompt_logprobs > 0: - kwargs['prompt_logprobs'] = prompt_logprobs + if self.prompt_logprobs is not None: + kwargs['prompt_logprobs'] = self.prompt_logprobs vllm_params = VLLMSamplingParams(**kwargs) - if num_samples > 1: + if self.num_samples > 1: from vllm.sampling_params import RequestOutputKind vllm_params.output_kind = RequestOutputKind.FINAL_ONLY return vllm_params From 7c726f7899d8565e5a844b31f5120d7c0f657beb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 13 Mar 2026 17:56:56 +0800 Subject: [PATCH 11/56] fix --- cookbook/rl/gkd_on_policy.py | 151 +++++++++++++++++++++++++---------- src/twinkle/loss/gkd.py | 5 +- 2 files changed, 112 insertions(+), 44 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index e652c35b..012f5e38 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -1,43 +1,45 @@ """GKD On-Policy Distillation via Ray. -On-policy knowledge distillation: teacher vLLM generates fresh responses for -each prompt, then the student learns to match the teacher's token distribution. +On-policy knowledge distillation: student vLLM generates responses, +teacher vLLM provides top-k prompt logprobs, then student model learns +to match the teacher's token distribution. Pipeline: 1. DataLoader supplies prompt-only batches. - 2. Teacher vLLM sampler generates completions on-the-fly. - 3. Teacher TransformersModel runs forward_only() to get frozen logits. + 2. Student vLLM sampler generates completions on-the-fly. + 3. 
Teacher vLLM sampler computes top-k prompt logprobs on generated sequences. 4. Student TransformersModel runs forward_backward() with GKDLoss. Architecture (Ray): ┌─────────────────────────────────────────────────────────────────┐ │ Driver (CPU) │ │ dataloader ──► prompt-only batch │ - │ teacher_sampler.sample() ──► on-policy completions │ - │ teacher_model.forward_only() ──► frozen teacher logits │ - │ student_model.forward_backward(teacher_logits=...) ──► GKD │ + │ student_sampler.sample() ──► on-policy completions │ + │ teacher_sampler.sample(topk_prompt_logprobs=k) ──► teacher lps│ + │ student_model.forward_backward(teacher_output=...) ──► GKD │ └─────────────────────────────────────────────────────────────────┘ │ │ │ - DataLoader vLLMSampler TransformersModel ×2 - (model GPUs) (sampler GPUs) student + teacher (model GPUs) + DataLoader vLLMSampler ×2 TransformersModel + (model GPUs) student + teacher (model GPUs) Environment variables (all optional): STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct) - TEACHER_MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct) - MODEL_GPUS – GPUs for student + teacher models (default: 4) - SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 4) - MAX_NEW_TOKENS – max completion tokens (default: 512) - BATCH_SIZE – global prompt-level batch size (default: 8) - MAX_STEPS – total optimisation steps (default: 200) - LR – learning rate (default: 1e-4) - GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) - GKD_TEMPERATURE – distillation temperature (default: 1.0) - GKD_TOPK – top-k vocab reduction; 0=full (default: 0) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-4B) + MODEL_GPUS – GPUs for student model (default: 4) + SAMPLER_GPUS – GPUs for each vLLM sampler (default: 2) + MAX_NEW_TOKENS – max completion tokens (default: 512) + BATCH_SIZE – global prompt-level batch size (default: 8) + MAX_STEPS – total optimisation steps (default: 200) + LR – learning rate (default: 1e-4) + GKD_BETA – JSD beta (0=fwd KL, 1=rev 
KL) (default: 0.5) + GKD_TEMPERATURE – distillation temperature (default: 1.0) + GKD_TOPK – top-k vocab for teacher logprobs (default: 10) """ import os -from typing import List +from typing import List, Optional +import torch from peft import LoraConfig import twinkle @@ -58,8 +60,8 @@ TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-4B') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) -NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) +NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 512)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) @@ -68,7 +70,7 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 0)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 10)) ADAPTER_NAME = 'default' @@ -76,7 +78,7 @@ # ── Dataset ─────────────────────────────────────────────────────────────────── def create_dataset(): - """Prompt-only dataset; teacher vLLM will generate completions on-policy.""" + """Prompt-only dataset; student vLLM will generate completions on-policy.""" dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) dataset.map(GSM8KProcessor()) @@ -84,12 +86,63 @@ def create_dataset(): return dataset +# ── Utility ─────────────────────────────────────────────────────────────────── + +def convert_topk_prompt_logprobs( + topk_prompt_logprobs_batch: List[List[Optional[List[tuple]]]], + device: str = 'cpu', +) -> dict: + """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format. + + Args: + topk_prompt_logprobs_batch: List of per-input topk_prompt_logprobs. + Each is List[Optional[List[(token_id, logprob)]]] of shape [seq_len, topk]. + device: Target device for tensors. 
def convert_topk_prompt_logprobs(
    topk_prompt_logprobs_batch: List[List[Optional[List[tuple]]]],
    device: str = 'cpu',
) -> dict:
    """Convert vLLM ``topk_prompt_logprobs`` to the GKDLoss teacher_output format.

    Args:
        topk_prompt_logprobs_batch: List of per-input topk_prompt_logprobs.
            Each is List[Optional[List[(token_id, logprob)]]] of shape
            [seq_len, topk]. vLLM returns ``None`` for positions that carry no
            logprobs (typically the first prompt token).
        device: Target device for the returned tensors.

    Returns:
        Dict with 'teacher_topk_logprobs' [batch, seq_len, topk] (float32) and
        'teacher_topk_indices' [batch, seq_len, topk] (int64). ``None``
        positions and right padding are zero-filled.
    """
    # Determine the top-k width from the first non-None position anywhere in
    # the batch, so every None position can be zero-padded to a rectangular
    # shape. (The previous code guessed the width from seq[1] only, which
    # produced ragged rows -- and a torch.tensor() failure -- whenever that
    # position was also None.)
    topk = 1
    for seq_topk in topk_prompt_logprobs_batch:
        width = next((len(pos) for pos in seq_topk if pos), None)
        if width is not None:
            topk = width
            break

    batch_logprobs = []
    batch_indices = []
    for seq_topk in topk_prompt_logprobs_batch:
        seq_logprobs = []
        seq_indices = []
        for pos_topk in seq_topk:
            if not pos_topk:
                # No teacher logprobs at this position: zero-fill uniformly.
                seq_logprobs.append([0.0] * topk)
                seq_indices.append([0] * topk)
            else:
                seq_logprobs.append([lp for _, lp in pos_topk])
                seq_indices.append([tid for tid, _ in pos_topk])
        batch_logprobs.append(seq_logprobs)
        batch_indices.append(seq_indices)

    # Right-pad to the same seq_len within the batch (zeros are masked out by
    # the loss via the label ignore_index, so padding values are inert).
    max_len = max((len(seq) for seq in batch_logprobs), default=0)
    for lp_rows, idx_rows in zip(batch_logprobs, batch_indices):
        pad_len = max_len - len(lp_rows)
        if pad_len > 0:
            lp_rows.extend([[0.0] * topk] * pad_len)
            idx_rows.extend([[0] * topk] * pad_len)

    return {
        'teacher_topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32, device=device),
        'teacher_topk_indices': torch.tensor(batch_indices, dtype=torch.long, device=device),
    }
main(): student_model = TransformersModel( model_id=STUDENT_MODEL_ID, device_mesh=model_mesh, - remote_group='model', + remote_group='student_model', ) student_model.add_adapter_to_model( ADAPTER_NAME, @@ -117,20 +170,21 @@ def main(): student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) student_model.set_template('Template', model_id=STUDENT_MODEL_ID) - # ── Teacher model (frozen, for logits) ───────────────────────────────────── - teacher_model = TransformersModel( - model_id=TEACHER_MODEL_ID, - device_mesh=model_mesh, - remote_group='model', + # ── Student vLLM sampler (for on-policy generation) ──────────────────────── + student_sampler = vLLMSampler( + model_id=STUDENT_MODEL_ID, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048}, + device_mesh=sampler_mesh, + remote_group='student_sampler', ) - teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID) + student_sampler.set_template(Template, model_id=STUDENT_MODEL_ID) - # ── Teacher vLLM sampler (for on-policy generation) ──────────────────────── + # ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs'}, device_mesh=sampler_mesh, - remote_group='sampler', + remote_group='teacher_sampler', ) teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID) @@ -140,11 +194,12 @@ def main(): batch_size=BATCH_SIZE, min_batch_size=BATCH_SIZE, device_mesh=model_mesh, - remote_group='model', + remote_group='student_model', ) sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0) - topk = GKD_TOPK if GKD_TOPK > 0 else None + # For teacher: only need prompt logprobs, no generation + teacher_sampling_params = SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=10) logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') 
logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') @@ -154,14 +209,24 @@ def main(): if optim_step >= MAX_STEPS: break - # Teacher vLLM generates completions - sample_response = teacher_sampler.sample(batch, sampling_params, num_samples=1) + # 1. Student vLLM generates completions + sample_response = student_sampler.sample(batch, sampling_params, num_samples=1) input_data = [seq.new_input_feature for seq in sample_response.sequences] - # Teacher logits (frozen) - teacher_output = teacher_model.forward_only(inputs=input_data) - # Student forward + GKD backward - student_model.forward_backward(inputs=input_data, teacher_output=teacher_output, topk=topk) + # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences + teacher_response = teacher_sampler.sample( + input_data, + teacher_sampling_params, + ) + + # 3. Convert teacher logprobs to tensor format for GKDLoss + teacher_output = convert_topk_prompt_logprobs( + teacher_response.topk_prompt_logprobs, + device='cuda', + ) + + # 4. 
Student forward + GKD backward + student_model.forward_backward(inputs=input_data, teacher_output=teacher_output) student_model.clip_grad_and_step() optim_step += 1 diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 659f552a..44249dfb 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -85,8 +85,11 @@ def __call__( # Align seq dimension: some MLLMs return extra prefix logits if student_logits.shape[1] != labels.shape[1]: student_logits = student_logits[:, -labels.shape[1]:] - if teacher_logits.shape[1] > student_logits.shape[1]: + if teacher_logits is not None and teacher_logits.shape[1] > student_logits.shape[1]: teacher_logits = teacher_logits[:, :student_logits.shape[1]] + if teacher_topk_logprobs is not None and teacher_topk_logprobs.shape[1] > student_logits.shape[1]: + teacher_topk_logprobs = teacher_topk_logprobs[:, :student_logits.shape[1]] + teacher_topk_indices = teacher_topk_indices[:, :student_logits.shape[1]] # Shift labels: label[i] = next token predicted by logits[i] # The last position wraps to label[0] via roll; since label[0] is -100 (prompt), From 4e6ac60db0db32c67662f5b9477f055e2158bcb4 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 13 Mar 2026 18:20:40 +0800 Subject: [PATCH 12/56] fix --- cookbook/rl/gkd_on_policy.py | 3 ++- src/twinkle/data_format/sampling.py | 18 +++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 012f5e38..85690c4d 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -199,7 +199,7 @@ def main(): sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0) # For teacher: only need prompt logprobs, no generation - teacher_sampling_params = SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=10) + teacher_sampling_params = SamplingParams(max_tokens=1, temperature=1.0) logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} 
teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') @@ -217,6 +217,7 @@ def main(): teacher_response = teacher_sampler.sample( input_data, teacher_sampling_params, + prompt_logprobs=10, ) # 3. Convert teacher logprobs to tensor format for GKDLoss diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 5b46e152..2ab5602b 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -17,11 +17,8 @@ class SamplingParams: top_k: int = -1 top_p: float = 1.0 repetition_penalty: float = 1.0 - logprobs: int = None - prompt_logprobs: int = None - num_samples: int = 1 - def to_vllm(self, **kwargs): + def to_vllm(self, *, num_samples: int = 1, logprobs: int = None, prompt_logprobs: int = 0): """Convert to vLLM SamplingParams. """ from vllm import SamplingParams as VLLMSamplingParams @@ -29,8 +26,7 @@ def to_vllm(self, **kwargs): kwargs = { 'temperature': self.temperature, 'top_p': self.top_p, - 'n': self.num_samples, - **kwargs, + 'n': num_samples, } if self.max_tokens is not None: @@ -53,14 +49,14 @@ def to_vllm(self, **kwargs): else: kwargs['stop'] = list(self.stop) - if self.logprobs is not None: - kwargs['logprobs'] = self.logprobs + if logprobs is not None: + kwargs['logprobs'] = logprobs - if self.prompt_logprobs is not None: - kwargs['prompt_logprobs'] = self.prompt_logprobs + if prompt_logprobs is not None: + kwargs['prompt_logprobs'] = prompt_logprobs vllm_params = VLLMSamplingParams(**kwargs) - if self.num_samples > 1: + if num_samples > 1: from vllm.sampling_params import RequestOutputKind vllm_params.output_kind = RequestOutputKind.FINAL_ONLY return vllm_params From 5ced90847a908b8f64f58c2c88100b202954d8cb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 14 Mar 2026 15:00:30 +0800 Subject: [PATCH 13/56] fix --- cookbook/rl/gkd_on_policy.py | 13 +++++-------- src/twinkle/data_format/sampling.py | 18 +++++++++++------- 
.../sampler/vllm_sampler/vllm_engine.py | 7 ++----- .../sampler/vllm_sampler/vllm_sampler.py | 4 ---- 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 85690c4d..36e2fa3f 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -197,10 +197,6 @@ def main(): remote_group='student_model', ) - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0) - # For teacher: only need prompt logprobs, no generation - teacher_sampling_params = SamplingParams(max_tokens=1, temperature=1.0) - logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') @@ -210,14 +206,15 @@ def main(): break # 1. Student vLLM generates completions - sample_response = student_sampler.sample(batch, sampling_params, num_samples=1) + sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0), num_samples=1) input_data = [seq.new_input_feature for seq in sample_response.sequences] - + for data in input_data: + data.pop('input_ids', None) + # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( input_data, - teacher_sampling_params, - prompt_logprobs=10, + SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=10), ) # 3. 
Convert teacher logprobs to tensor format for GKDLoss diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 2ab5602b..5b46e152 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -17,8 +17,11 @@ class SamplingParams: top_k: int = -1 top_p: float = 1.0 repetition_penalty: float = 1.0 + logprobs: int = None + prompt_logprobs: int = None + num_samples: int = 1 - def to_vllm(self, *, num_samples: int = 1, logprobs: int = None, prompt_logprobs: int = 0): + def to_vllm(self, **kwargs): """Convert to vLLM SamplingParams. """ from vllm import SamplingParams as VLLMSamplingParams @@ -26,7 +29,8 @@ def to_vllm(self, *, num_samples: int = 1, logprobs: int = None, prompt_logprobs kwargs = { 'temperature': self.temperature, 'top_p': self.top_p, - 'n': num_samples, + 'n': self.num_samples, + **kwargs, } if self.max_tokens is not None: @@ -49,14 +53,14 @@ def to_vllm(self, *, num_samples: int = 1, logprobs: int = None, prompt_logprobs else: kwargs['stop'] = list(self.stop) - if logprobs is not None: - kwargs['logprobs'] = logprobs + if self.logprobs is not None: + kwargs['logprobs'] = self.logprobs - if prompt_logprobs is not None: - kwargs['prompt_logprobs'] = prompt_logprobs + if self.prompt_logprobs is not None: + kwargs['prompt_logprobs'] = self.prompt_logprobs vllm_params = VLLMSamplingParams(**kwargs) - if num_samples > 1: + if self.num_samples > 1: from vllm.sampling_params import RequestOutputKind vllm_params.output_kind = RequestOutputKind.FINAL_ONLY return vllm_params diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index a12a1e40..1d4ec331 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -190,6 +190,7 @@ async def sample( *, images: Optional[List[Any]] = None, videos: Optional[List[Any]] = None, + **kwargs ) -> SampleResponse: """ Sample completions from the 
model. @@ -220,11 +221,7 @@ async def sample( if isinstance(sampling_params, dict): sampling_params = SamplingParams.from_dict(sampling_params) prompt_logprobs_k = topk_prompt_logprobs if topk_prompt_logprobs > 0 else (1 if include_prompt_logprobs else 0) - vllm_params = sampling_params.to_vllm( - num_samples=num_samples, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs_k, - ) + vllm_params = sampling_params.to_vllm(**kwargs) # Build request if request_id is None: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index b4d1c6fd..c3ac44f6 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -275,8 +275,6 @@ def sample( adapter_name: str = '', adapter_path: Optional[str] = None, *, - logprobs: bool = True, - num_samples: int = 1, return_encoded: bool = False, ) -> SampleResponse: """Sample responses for given inputs. @@ -337,8 +335,6 @@ async def _sample_all(): feat, sampling_params, lora_request=lora_request, - logprobs=logprobs, - num_samples=num_samples, ) for feat in encoded_inputs ] return await asyncio.gather(*tasks) From 1524dbf33085b51972bfd19bc618591eed919769 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 14 Mar 2026 20:33:06 +0800 Subject: [PATCH 14/56] wip --- cookbook/rl/gkd_on_policy.py | 4 +- .../sampler/vllm_sampler/vllm_engine.py | 6 +-- .../sampler/vllm_sampler/vllm_sampler.py | 42 +++---------------- 3 files changed, 9 insertions(+), 43 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 36e2fa3f..2e5d0f47 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -206,8 +206,8 @@ def main(): break # 1. 
Student vLLM generates completions - sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0), num_samples=1) - input_data = [seq.new_input_feature for seq in sample_response.sequences] + sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=1)) + input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] for data in input_data: data.pop('input_ids', None) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 1d4ec331..66f01fe8 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -180,10 +180,6 @@ async def sample( self, prompt_token_ids: List[int], sampling_params: Union[SamplingParams, Dict[str, Any]], - num_samples: int = 1, - logprobs: bool = True, - include_prompt_logprobs: bool = False, - topk_prompt_logprobs: int = 0, lora_request: Optional[Any] = None, request_id: Optional[str] = None, priority: int = 0, @@ -220,7 +216,7 @@ async def sample( # Convert to vLLM params if isinstance(sampling_params, dict): sampling_params = SamplingParams.from_dict(sampling_params) - prompt_logprobs_k = topk_prompt_logprobs if topk_prompt_logprobs > 0 else (1 if include_prompt_logprobs else 0) + prompt_logprobs_k = sampling_params.prompt_logprobs or 0 vllm_params = sampling_params.to_vllm(**kwargs) # Build request diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index c3ac44f6..e33e178a 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -35,29 +35,6 @@ logger = get_logger() -def _collect_sample_responses(results: List[SampleResponse], **kwargs) -> SampleResponse: - """Custom collect function to merge multiple SampleResponse objects. 
- - Args: - results: List of SampleResponse from each DP worker. - - Returns: - Merged SampleResponse with all sequences combined. - """ - if not results: - return SampleResponse(sequences=[]) - - if len(results) == 1: - return results[0] - - all_sequences = [] - for resp in results: - if resp is not None and hasattr(resp, 'sequences'): - all_sequences.extend(resp.sequences) - - return SampleResponse(sequences=all_sequences) - - @remote_class() class vLLMSampler(Sampler, CheckpointEngineMixin): """A vLLM-based sampler using VLLMEngine (AsyncLLM). @@ -224,7 +201,6 @@ async def _sample_single( lora_request: Optional[Any] = None, *, logprobs: bool = True, - num_samples: int = 1, ) -> List[SampledSequence]: """Sample a single input asynchronously. @@ -250,14 +226,13 @@ async def _sample_single( prompt_token_ids=input_ids, sampling_params=sampling_params, logprobs=logprobs, - num_samples=num_samples, lora_request=lora_request, images=images, videos=videos, ) # response.sequences contains num_samples sequences for this prompt - return [ + return SampleResponse(sequences=[ SampledSequence( stop_reason=seq.stop_reason, tokens=seq.tokens, @@ -265,9 +240,9 @@ async def _sample_single( decoded=self.template.decode(seq.tokens), new_input_feature=self.template.concat_input_feature(feat, seq.tokens), ) for seq in response.sequences - ] + ], prompt_logprobs=response.prompt_logprobs, topk_prompt_logprobs=response.topk_prompt_logprobs) - @remote_function(dispatch='slice_dp', collect=_collect_sample_responses, lazy_collect=False) + @remote_function(dispatch='slice_dp', collect='flatten', lazy_collect=False) def sample( self, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], @@ -276,7 +251,7 @@ def sample( adapter_path: Optional[str] = None, *, return_encoded: bool = False, - ) -> SampleResponse: + ) -> List[SampleResponse]: """Sample responses for given inputs. 
Args: @@ -300,7 +275,6 @@ def sample( Note: In Ray mode with multiple workers (DP > 1): - Data is automatically sliced by DP rank (dispatch='slice_dp') - - Results are merged using _collect_sample_responses - Each worker receives already-sliced inputs (e.g., DP4 with 8 inputs -> 2 per worker) """ if sampling_params is None: @@ -339,12 +313,8 @@ async def _sample_all(): ] return await asyncio.gather(*tasks) - results = self._run_in_loop(_sample_all()) - # Flatten results (each result contains num_samples sequences) - all_sequences = [] - for seqs in results: - all_sequences.extend(seqs) - return SampleResponse(sequences=all_sequences) + sample_results = self._run_in_loop(_sample_all()) + return sample_results @remote_function(dispatch='all', collect='first') def sleep(self, level: int = 1) -> None: From 30df960ef5828808f0e0f7a62d0b9951f7f338d3 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 14 Mar 2026 20:37:16 +0800 Subject: [PATCH 15/56] fix --- cookbook/rl/gkd_on_policy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 2e5d0f47..c27f597c 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -218,8 +218,9 @@ def main(): ) # 3. 
Convert teacher logprobs to tensor format for GKDLoss + # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each teacher_output = convert_topk_prompt_logprobs( - teacher_response.topk_prompt_logprobs, + [resp.topk_prompt_logprobs for resp in teacher_response], device='cuda', ) From 43be0f852a3b6782722357f527bc83c43f12c46d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 14 Mar 2026 21:42:34 +0800 Subject: [PATCH 16/56] wip --- cookbook/rl/gkd_on_policy.py | 7 +++---- src/twinkle/loss/gkd.py | 25 ++++++++++++++----------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index c27f597c..d78a0e33 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -131,8 +131,8 @@ def convert_topk_prompt_logprobs( batch_indices[i].extend([[0] * topk] * pad_len) return { - 'topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32, device=device), - 'topk_indices': torch.tensor(batch_indices, dtype=torch.long, device=device), + 'teacher_topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32, device=device), + 'teacher_topk_indices': torch.tensor(batch_indices, dtype=torch.long, device=device), } @@ -221,11 +221,10 @@ def main(): # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each teacher_output = convert_topk_prompt_logprobs( [resp.topk_prompt_logprobs for resp in teacher_response], - device='cuda', ) # 4. 
Student forward + GKD backward - student_model.forward_backward(inputs=input_data, teacher_output=teacher_output) + student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() optim_step += 1 diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 44249dfb..93fbb9de 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -49,7 +49,9 @@ def __call__( inputs, outputs, *, - teacher_output: Optional['torch.Tensor'] = None, + teacher_logits: Optional['torch.Tensor'] = None, + teacher_topk_logprobs: Optional['torch.Tensor'] = None, + teacher_topk_indices: Optional['torch.Tensor'] = None, topk: Optional[int] = None, **kwargs, ) -> LossOutput: @@ -58,22 +60,18 @@ def __call__( Args: inputs: Dict containing 'labels' [batch, seq_len] with ignore_index for non-response tokens. outputs: Dict containing 'logits' [batch, seq_len, vocab_size] from the student model. - teacher_output: A dict contains: - teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from a local teacher. - Either teacher_logits or (teacher_topk_logprobs + teacher_topk_indices) - must be provided. - teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote teacher API. - Returned by a vLLM-compatible /v1/completions prompt_logprobs call. - teacher_topk_indices: [batch, seq_len, topk] token indices corresponding to teacher_topk_logprobs. + teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from a local teacher. + Either teacher_logits or (teacher_topk_logprobs + teacher_topk_indices) + must be provided. + teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote teacher API. + Returned by a vLLM-compatible /v1/completions prompt_logprobs call. + teacher_topk_indices: [batch, seq_len, topk] token indices corresponding to teacher_topk_logprobs. 
topk: If set together with teacher_logits, only the top-k teacher tokens are used to reduce vocabulary size before computing the JSD (memory-efficient local teacher mode). Returns: LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens. """ - teacher_logits = teacher_output.get('logits') - teacher_topk_logprobs = teacher_output.get('topk_logprobs') - teacher_topk_indices = teacher_output.get('topk_indices') assert teacher_logits is not None or ( teacher_topk_logprobs is not None and teacher_topk_indices is not None ), ( @@ -152,12 +150,15 @@ def _generalized_jsd_loss( if teacher_topk_logprobs is not None and teacher_topk_indices is not None: # Remote API teacher: teacher already provides top-k log-probs (T=1). # Gather student logits at teacher's top-k indices, then scale in-place. + teacher_topk_indices = teacher_topk_indices.to(student_logits.device) + teacher_topk_logprobs = teacher_topk_logprobs.to(student_logits.device) student_logits = torch.gather(student_logits, dim=-1, index=teacher_topk_indices) student_logits.div_(temperature) teacher_logits = teacher_topk_logprobs / temperature temperature = 1.0 elif topk is not None and teacher_logits is not None: # Local teacher: select top-k from teacher, gather corresponding student logits + teacher_logits = teacher_logits.to(student_logits.device()) teacher_logits, topk_idx = torch.topk(teacher_logits, k=topk, dim=-1) teacher_logits.div_(temperature) student_logits = torch.gather(student_logits, dim=-1, index=topk_idx) @@ -177,6 +178,8 @@ def _generalized_jsd_loss( elif stu_dim > tea_dim: teacher_logits = F.pad(teacher_logits, (0, stu_dim - tea_dim)) teacher_logits[..., tea_dim:] = student_logits[..., tea_dim:] + if student_logits.shape[1] != mask.shape[1]: + breakpoint() student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() From 04493405c63d4829d54bd61c062378a2e2a366c5 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: 
Sun, 15 Mar 2026 18:15:01 +0800 Subject: [PATCH 17/56] no message --- src/twinkle/template/base.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 33e970e8..e7262b91 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -52,6 +52,7 @@ def __init__(self, self._test_support_assistant_tokens_mask() self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [ self._add_default_system, # Add a default system field + self._to_standard_reasoning_content, # Convert thinking to standard field self._build_mm_messages, # turn to standard mm messages ] self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ @@ -183,6 +184,36 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] + def _to_standard_reasoning_content(self, trajectory: Trajectory) -> List[Trajectory]: + + def _extract_reasoning_content(messages: list[Message]) -> List[Message]: + result = [] + for message in messages: + message = message.copy() + if message.get("role") == "assistant": + content = message.get("content", "") + if "reasoning_content" not in message and isinstance(content, str): + if "" in content: + reasoning_content = content.split("")[0].rstrip("\n").split("")[-1].lstrip( + "\n") + new_content = content.split("")[-1].lstrip("\n") + + message["reasoning_content"] = reasoning_content + message["content"] = new_content + + result.append(message) + + return result + + trajectory['messages'] = _extract_reasoning_content(trajectory['messages']) + extra_messages = trajectory.get('extend_message', []) + if extra_messages: + result = [] + for key, extra_message in trajectory.get('extend_message', []): + result.append((key, _extract_reasoning_content(extra_message))) + trajectory['extend_message'] = result + return [trajectory] + def 
_check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: if self.max_length and len(input_feature['input_ids']) > self.max_length: if self.truncation_strategy == 'raise': From 4296d62da2e325d2b03caa1f4f046fea82a305e3 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Mar 2026 19:01:38 +0800 Subject: [PATCH 18/56] wip --- src/twinkle/template/base.py | 7 +- src/twinkle/template/utils.py | 248 ++++++++++++++++++---------------- 2 files changed, 135 insertions(+), 120 deletions(-) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index e7262b91..60c8b723 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -9,7 +9,7 @@ from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device -from .utils import tokenize_with_assistant_labels, transfer_to_standard_message +from .utils import tokenize_with_assistant_labels, transfer_to_standard_message, TokenizeByRound if TYPE_CHECKING: import torch @@ -298,8 +298,9 @@ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> assistant_masks = encoded.pop('assistant_masks') labels = np.where(assistant_masks, input_ids, -100) else: - input_ids, labels, encoded = tokenize_with_assistant_labels(self.tokenizer, self._apply_chat_template, - trajectory) + input_ids, labels, encoded = TokenizeByRound.tokenize_with_assistant_labels(self.tokenizer, + self._apply_chat_template, + trajectory) else: assert len(trajectory['messages']) == 1 and trajectory['messages'][0]['role'] == 'user' text = trajectory['messages'][0]['content'] diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index ccdcfc91..53cdecab 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -11,59 +11,6 @@ _T = TypeVar('_T') -PLACEHOLDER = '<<>>' - - -def find_subsequence(seq: List[int], subseq: List[int], start: int = 0) -> int: - """Find 
the first index of `subseq`""" - subseq_len = len(subseq) - for i in range(start, len(seq) - subseq_len + 1): - if seq[i:i + subseq_len] == subseq: - return i - return -1 - - -def split_by_subsequence(seq: List[int], subseq: List[int]) -> List[List[int]]: - """Split seq by subseq""" - parts = [] - start = 0 - subseq_len = len(subseq) - - while True: - pos = find_subsequence(seq, subseq, start) - if pos == -1: - parts.append(seq[start:]) - break - parts.append(seq[start:pos]) - start = pos + subseq_len - - return parts - - -def build_labels( - full_ids: List[int], - template_parts: List[List[int]], -) -> List[int]: - labels = list(full_ids) - pos = 0 - - for part in template_parts: - if not part: - continue - - match_pos = find_subsequence(full_ids, part, pos) - - if match_pos == -1: - # should not happen - raise ValueError(f'Template part not found in full_ids at position {pos}') - - for i in range(match_pos, match_pos + len(part)): - labels[i] = -100 - - pos = match_pos + len(part) - - return labels - def _convert_to_vlm_format(messages: List[Dict]) -> List[Dict]: converted = [] @@ -83,70 +30,6 @@ def _is_vlm_processor(tokenizer) -> bool: return False -def tokenize_with_assistant_labels( - tokenizer: 'PreTrainedTokenizer', - encode_func: Callable, - trajectory: Trajectory, - placeholder: str = PLACEHOLDER, -) -> Tuple[List[int], List[int], Dict[str, Any]]: - import torch - messages = [dict(message) for message in trajectory['messages']] - - _dummy_messages = [] - assistant_count = 0 - for msg in messages: - if msg['role'] == 'assistant': - msg = deepcopy(msg) - if isinstance(msg['content'], str): - msg['content'] = placeholder - else: - msg['content'][0]['text'] = placeholder - assistant_count += 1 - _dummy_messages.append(msg) - - encoded = encode_func(trajectory) - full_ids = encoded.pop('input_ids') - if isinstance(full_ids, torch.Tensor): - full_ids = full_ids.tolist()[0] - - _dummy_trajectory = copy(trajectory) - _dummy_trajectory['messages'] = 
_dummy_messages - template_ids = encode_func(_dummy_trajectory) - template_ids = template_ids['input_ids'] - if isinstance(template_ids, torch.Tensor): - template_ids = template_ids.tolist()[0] - - extra_kwargs = {} - if 'add_special_tokens' in inspect.signature(tokenizer.encode).parameters: - extra_kwargs['add_special_tokens'] = False - placeholder_ids = tokenizer.encode(placeholder, **extra_kwargs) - template_parts = split_by_subsequence(template_ids, placeholder_ids) - - if len(template_parts) != assistant_count + 1: - raise ValueError(f'Expected {assistant_count + 1} parts, got {len(template_parts)}. ' - 'Placeholder might appear in original content.') - - try: - labels = build_labels(full_ids, template_parts) - except ValueError as e: - newline_placeholder_ids = tokenizer.encode('\n' + placeholder, **extra_kwargs) - template_parts = split_by_subsequence(template_ids, newline_placeholder_ids) - if len(template_parts) == assistant_count + 1: - labels = build_labels(full_ids, template_parts) - else: - raise e - if labels and labels[-1] == -100: - end_idx = len(labels) - start_idx = end_idx - 1 - while start_idx > 0 and labels[start_idx - 1] == -100: - start_idx -= 1 - - for i in range(start_idx, end_idx): - labels[i] = full_ids[i] - - return full_ids, labels, encoded - - def _load_image(img: Any) -> Optional[Any]: """Load images to PIL format.""" import io @@ -297,3 +180,134 @@ def get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config): video_mask = video_mask.to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) return inputs_embeds + + +class TokenizeByRound: + + @staticmethod + def tokenize_with_assistant_labels( + tokenizer: 'PreTrainedTokenizer', + encode_func: Callable, + trajectory: Trajectory, + ) -> Tuple[List[int], List[int], Dict[str, Any]]: + pass + + +class TokenizeByPlaceHolder: + + PLACEHOLDER = '<<>>' + + @staticmethod + def find_subsequence(seq: List[int], subseq: List[int], start: int = 
0) -> int: + """Find the first index of `subseq`""" + subseq_len = len(subseq) + for i in range(start, len(seq) - subseq_len + 1): + if seq[i:i + subseq_len] == subseq: + return i + return -1 + + @staticmethod + def split_by_subsequence(seq: List[int], subseq: List[int]) -> List[List[int]]: + """Split seq by subseq""" + parts = [] + start = 0 + subseq_len = len(subseq) + + while True: + pos = TokenizeByPlaceHolder.find_subsequence(seq, subseq, start) + if pos == -1: + parts.append(seq[start:]) + break + parts.append(seq[start:pos]) + start = pos + subseq_len + + return parts + + @staticmethod + def build_labels( + full_ids: List[int], + template_parts: List[List[int]], + ) -> List[int]: + labels = list(full_ids) + pos = 0 + + for part in template_parts: + if not part: + continue + + match_pos = TokenizeByPlaceHolder.find_subsequence(full_ids, part, pos) + + if match_pos == -1: + # should not happen + raise ValueError(f'Template part not found in full_ids at position {pos}') + + for i in range(match_pos, match_pos + len(part)): + labels[i] = -100 + + pos = match_pos + len(part) + + return labels + + @staticmethod + def tokenize_with_assistant_labels( + tokenizer: 'PreTrainedTokenizer', + encode_func: Callable, + trajectory: Trajectory, + ) -> Tuple[List[int], List[int], Dict[str, Any]]: + import torch + placeholder: str = TokenizeByPlaceHolder.PLACEHOLDER + messages = [dict(message) for message in trajectory['messages']] + + _dummy_messages = [] + assistant_count = 0 + for msg in messages: + if msg['role'] == 'assistant': + msg = deepcopy(msg) + if isinstance(msg['content'], str): + msg['content'] = placeholder + else: + msg['content'][0]['text'] = placeholder + assistant_count += 1 + _dummy_messages.append(msg) + + encoded = encode_func(trajectory) + full_ids = encoded.pop('input_ids') + if isinstance(full_ids, torch.Tensor): + full_ids = full_ids.tolist()[0] + + _dummy_trajectory = copy(trajectory) + _dummy_trajectory['messages'] = _dummy_messages + template_ids = 
encode_func(_dummy_trajectory) + template_ids = template_ids['input_ids'] + if isinstance(template_ids, torch.Tensor): + template_ids = template_ids.tolist()[0] + + extra_kwargs = {} + if 'add_special_tokens' in inspect.signature(tokenizer.encode).parameters: + extra_kwargs['add_special_tokens'] = False + placeholder_ids = tokenizer.encode(placeholder, **extra_kwargs) + template_parts = TokenizeByPlaceHolder.split_by_subsequence(template_ids, placeholder_ids) + + if len(template_parts) != assistant_count + 1: + raise ValueError(f'Expected {assistant_count + 1} parts, got {len(template_parts)}. ' + 'Placeholder might appear in original content.') + + try: + labels = TokenizeByPlaceHolder.build_labels(full_ids, template_parts) + except ValueError as e: + newline_placeholder_ids = tokenizer.encode('\n' + placeholder, **extra_kwargs) + template_parts = TokenizeByPlaceHolder.split_by_subsequence(template_ids, newline_placeholder_ids) + if len(template_parts) == assistant_count + 1: + labels = TokenizeByPlaceHolder.build_labels(full_ids, template_parts) + else: + raise e + if labels and labels[-1] == -100: + end_idx = len(labels) + start_idx = end_idx - 1 + while start_idx > 0 and labels[start_idx - 1] == -100: + start_idx -= 1 + + for i in range(start_idx, end_idx): + labels[i] = full_ids[i] + + return full_ids, labels, encoded From 39f944925dbd55aef2108bffe7d65e36bf3cbe3d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Mar 2026 19:07:10 +0800 Subject: [PATCH 19/56] fix --- src/twinkle/template/utils.py | 68 ++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 53cdecab..5462105f 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -183,6 +183,15 @@ def get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config): class TokenizeByRound: + """Tokenize by encoding messages round-by-round. 
+ + This approach handles tags correctly by encoding each message + incrementally, determining token boundaries by comparing consecutive encode results. + + Unlike TokenizeByPlaceHolder which uses dummy placeholders and may fail when + reasoning_content causes extra tags, this method directly encodes + the trajectory up to each message and compares lengths to find token ranges. + """ @staticmethod def tokenize_with_assistant_labels( @@ -190,7 +199,64 @@ def tokenize_with_assistant_labels( encode_func: Callable, trajectory: Trajectory, ) -> Tuple[List[int], List[int], Dict[str, Any]]: - pass + """Tokenize trajectory and generate labels for assistant turns. + + Args: + tokenizer: The tokenizer (used for decoding if needed). + encode_func: Function to encode a trajectory. + trajectory: The trajectory containing messages. + + Returns: + Tuple of (input_ids, labels, extra_encoded_fields). + Labels are -100 for non-assistant tokens, original token id for assistant tokens. + """ + import torch + messages = trajectory['messages'] + + # Encode full trajectory + encoded = encode_func(trajectory) + full_ids = encoded.pop('input_ids') + if isinstance(full_ids, torch.Tensor): + full_ids = full_ids.tolist()[0] + + # Initialize labels: all -100 (not trained) + labels = [-100] * len(full_ids) + + if not messages: + return full_ids, labels, encoded + + # Encode round by round to find token boundaries + # prev_len tracks where the previous message ended + prev_len = 0 + + for i, msg in enumerate(messages): + # Create partial trajectory up to current message (inclusive) + partial_trajectory = copy(trajectory) + partial_trajectory['messages'] = list(messages[:i + 1]) + + # Encode partial trajectory + partial_encoded = encode_func(partial_trajectory) + partial_ids = partial_encoded['input_ids'] + if isinstance(partial_ids, torch.Tensor): + partial_ids = partial_ids.tolist()[0] + + curr_len = len(partial_ids) + + # If this is an assistant message, mark those tokens as trainable + if 
msg['role'] == 'assistant': + # Tokens from prev_len to curr_len belong to this assistant turn + for j in range(prev_len, min(curr_len, len(full_ids))): + labels[j] = full_ids[j] + + prev_len = curr_len + + # Handle any remaining tokens after the last message (e.g., EOS added by full encode) + # If the last message was assistant, these trailing tokens should also be trainable + if messages and messages[-1]['role'] == 'assistant': + for j in range(prev_len, len(full_ids)): + labels[j] = full_ids[j] + + return full_ids, labels, encoded class TokenizeByPlaceHolder: From a903cb99787cf42c728ca07e41cd3fd662db779f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Mar 2026 19:37:38 +0800 Subject: [PATCH 20/56] wip --- cookbook/rl/gkd_on_policy.py | 4 +-- src/twinkle/template/base.py | 2 +- src/twinkle/template/utils.py | 52 ++++++++++++++++------------------- 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index d78a0e33..2513d1e1 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -56,8 +56,8 @@ logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-4B') +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 60c8b723..6a2055c5 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -9,7 +9,7 @@ from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device -from 
.utils import tokenize_with_assistant_labels, transfer_to_standard_message, TokenizeByRound +from .utils import transfer_to_standard_message, TokenizeByRound if TYPE_CHECKING: import torch diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 5462105f..e343b419 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -188,9 +188,8 @@ class TokenizeByRound: This approach handles tags correctly by encoding each message incrementally, determining token boundaries by comparing consecutive encode results. - Unlike TokenizeByPlaceHolder which uses dummy placeholders and may fail when - reasoning_content causes extra tags, this method directly encodes - the trajectory up to each message and compares lengths to find token ranges. + For assistant messages, uses add_generation_prompt=True to exclude the assistant + prefix (e.g., '<|im_start|>assistant\n') from training labels. """ @staticmethod @@ -202,13 +201,14 @@ def tokenize_with_assistant_labels( """Tokenize trajectory and generate labels for assistant turns. Args: - tokenizer: The tokenizer (used for decoding if needed). - encode_func: Function to encode a trajectory. + tokenizer: The tokenizer (unused, kept for interface compatibility). + encode_func: Function to encode a trajectory. Must support add_generation_prompt. trajectory: The trajectory containing messages. Returns: Tuple of (input_ids, labels, extra_encoded_fields). - Labels are -100 for non-assistant tokens, original token id for assistant tokens. + Labels are -100 for non-assistant tokens, original token id for assistant content tokens. + Assistant prefix tokens (e.g., '<|im_start|>assistant\n') are excluded from training. 
""" import torch messages = trajectory['messages'] @@ -225,35 +225,29 @@ def tokenize_with_assistant_labels( if not messages: return full_ids, labels, encoded - # Encode round by round to find token boundaries - # prev_len tracks where the previous message ended - prev_len = 0 - for i, msg in enumerate(messages): - # Create partial trajectory up to current message (inclusive) - partial_trajectory = copy(trajectory) - partial_trajectory['messages'] = list(messages[:i + 1]) + if msg['role'] != 'assistant': + continue - # Encode partial trajectory - partial_encoded = encode_func(partial_trajectory) - partial_ids = partial_encoded['input_ids'] + # Get position AFTER assistant prefix: + # encode(messages[:i], add_generation_prompt=True) includes the prefix + partial_trajectory = copy(trajectory) + partial_trajectory['messages'] = list(messages[:i]) + partial_ids = encode_func(partial_trajectory, add_generation_prompt=True)['input_ids'] if isinstance(partial_ids, torch.Tensor): partial_ids = partial_ids.tolist()[0] + start_pos = len(partial_ids) - curr_len = len(partial_ids) - - # If this is an assistant message, mark those tokens as trainable - if msg['role'] == 'assistant': - # Tokens from prev_len to curr_len belong to this assistant turn - for j in range(prev_len, min(curr_len, len(full_ids))): - labels[j] = full_ids[j] - - prev_len = curr_len + # Get end position: encode(messages[:i+1]) includes full assistant turn + partial_trajectory = copy(trajectory) + partial_trajectory['messages'] = list(messages[:i + 1]) + partial_ids = encode_func(partial_trajectory)['input_ids'] + if isinstance(partial_ids, torch.Tensor): + partial_ids = partial_ids.tolist()[0] + end_pos = len(partial_ids) - # Handle any remaining tokens after the last message (e.g., EOS added by full encode) - # If the last message was assistant, these trailing tokens should also be trainable - if messages and messages[-1]['role'] == 'assistant': - for j in range(prev_len, len(full_ids)): + # Mark 
assistant CONTENT tokens as trainable (excluding prefix) + for j in range(start_pos, min(end_pos, len(full_ids))): labels[j] = full_ids[j] return full_ids, labels, encoded From 926210cfb8e00833420074839b8991cf4977c26a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Mar 2026 21:16:07 +0800 Subject: [PATCH 21/56] wip --- cookbook/rl/gkd_on_policy.py | 8 +++++--- src/twinkle/loss/gkd.py | 9 ++------- src/twinkle/metric/loss.py | 6 +++--- src/twinkle/model/transformers/transformers.py | 4 +++- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 2513d1e1..fb9d429b 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -70,7 +70,7 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 10)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 20)) ADAPTER_NAME = 'default' @@ -208,14 +208,16 @@ def main(): # 1. Student vLLM generates completions sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=1)) input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] + input_ids_list = [] for data in input_data: - data.pop('input_ids', None) + input_ids_list.append(data.pop('input_ids', None)) # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( input_data, - SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=10), + SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), ) + teacher_input_ids = teacher_response[0].sequences[0].new_input_feature['input_ids'] # 3. 
Convert teacher logprobs to tensor format for GKDLoss # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 93fbb9de..a46e4ffd 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -89,15 +89,12 @@ def __call__( teacher_topk_logprobs = teacher_topk_logprobs[:, :student_logits.shape[1]] teacher_topk_indices = teacher_topk_indices[:, :student_logits.shape[1]] - # Shift labels: label[i] = next token predicted by logits[i] - # The last position wraps to label[0] via roll; since label[0] is -100 (prompt), - # it will be correctly excluded by the mask in _generalized_jsd_loss. - shifted_labels = labels.roll(shifts=-1, dims=1) + shifted_labels = labels loss = self._generalized_jsd_loss( student_logits=student_logits, teacher_logits=teacher_logits, - labels=shifted_labels, + labels=labels, beta=self.beta, temperature=self.temperature, chunk_size=self.chunk_size, @@ -178,8 +175,6 @@ def _generalized_jsd_loss( elif stu_dim > tea_dim: teacher_logits = F.pad(teacher_logits, (0, stu_dim - tea_dim)) teacher_logits[..., tea_dim:] = student_logits[..., tea_dim:] - if student_logits.shape[1] != mask.shape[1]: - breakpoint() student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index b15f1f96..52f50fdd 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -13,19 +13,19 @@ class LossMetric(Metric): process_group: The process group to collect data from """ - def __init__(self, device_mesh, process_group, loss_reduction='mean', **kwargs): + def __init__(self, device_mesh, process_group, **kwargs): super().__init__(device_mesh, process_group, **kwargs) self.total_loss = 0 self.total_count = 0 self.grad_norm = 0 self.num_tokens = 0 - self.loss_reduction = loss_reduction def accumulate(self, inputs: 
Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): if 'loss' not in outputs: return loss = outputs['loss'] - if self.loss_reduction == 'sum': + loss_reduction = kwargs.get('loss_reduction', 'mean') + if loss_reduction == 'sum': if not isinstance(inputs, list): inputs = [inputs] for input in inputs: diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 82997d19..908b2d9a 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -128,7 +128,9 @@ def accumulate_metrics(self, is_training): lr=self._get_lr(), step=self.cur_step - 1, gradient_accumulation_steps=self.gradient_accumulation_steps, - grad_norm=self._last_grad_norm) + grad_norm=self._last_grad_norm, + loss_reduction=getattr(self.loss_instance, 'reduction', 'mean') + ) def calculate_metrics(self, is_training): self.accumulate_metrics(is_training) From c49fccd2be905606aa3b0eeebd124dd426688a85 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 16 Mar 2026 11:13:22 +0800 Subject: [PATCH 22/56] wip --- cookbook/rl/gkd_on_policy.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index fb9d429b..8d234c24 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -111,9 +111,7 @@ def convert_topk_prompt_logprobs( seq_indices = [] for pos_topk in seq_topk: if pos_topk is None: - # First position typically has no logprobs - seq_logprobs.append([0.0] * len(seq_topk[1]) if len(seq_topk) > 1 and seq_topk[1] else [0.0]) - seq_indices.append([0] * len(seq_topk[1]) if len(seq_topk) > 1 and seq_topk[1] else [0]) + continue else: seq_logprobs.append([lp for _, lp in pos_topk]) seq_indices.append([tid for tid, _ in pos_topk]) From 1e7240f0d5f0063554c764b315f42aa55646e32d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 16 Mar 2026 11:17:44 +0800 Subject: [PATCH 23/56] fix --- 
cookbook/rl/gkd_off_policy.py | 123 ++++++++++++++++++++++++---------- 1 file changed, 89 insertions(+), 34 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 4704e2ea..047f6c69 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -5,56 +5,63 @@ Pipeline: 1. DataLoader supplies full-text batches (prompt + reference answer). - 2. Teacher TransformersModel runs forward_only() to get frozen logits. + 2. Teacher vLLM sampler computes top-k prompt logprobs on the sequences. 3. Student TransformersModel runs forward_backward() with GKDLoss. Key difference from on-policy: - - No vLLM sampler needed (responses already in the dataset). - - Simpler GPU layout: all GPUs can go to the model group. + - No student sampler needed (responses already in the dataset). - Faster per-step (no generation latency), but less exploration. Architecture (Ray): ┌─────────────────────────────────────────────────────────────────┐ │ Driver (CPU) │ │ dataloader ──► full-text batch (prompt + reference answer) │ - │ teacher_model.forward_only() ──► frozen teacher logits │ - │ student_model.forward_backward(teacher_logits=...) ──► GKD │ + │ teacher_sampler.sample(prompt_logprobs=k) ──► teacher lps │ + │ student_model.forward_backward(teacher_output=...) 
──► GKD │ └─────────────────────────────────────────────────────────────────┘ │ - TransformersModel ×2 - student + teacher (all GPUs) + vLLMSampler + TransformersModel + (teacher) (student) Environment variables (all optional): - STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct) - TEACHER_MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct) - NUM_GPUS – total GPUs for both models (default: 4) - BATCH_SIZE – global batch size (default: 8) - MAX_STEPS – total optimisation steps (default: 200) - LR – learning rate (default: 1e-4) + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B) + MODEL_GPUS – GPUs for student model (default: 4) + SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 2) + BATCH_SIZE – global batch size (default: 8) + MAX_STEPS – total optimisation steps (default: 200) + LR – learning rate (default: 1e-4) GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) GKD_TEMPERATURE – distillation temperature (default: 1.0) - GKD_TOPK – top-k vocab reduction; 0=full (default: 0) + GKD_TOPK – top-k vocab for teacher logprobs (default: 20) """ import os +from typing import List, Optional +import torch from peft import LoraConfig import twinkle from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import GKDLoss from twinkle.model import TransformersModel from twinkle.preprocessor import GSM8KFullProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Template logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 
'ms://Qwen/Qwen3-0.6B') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') -NUM_GPUS = int(os.environ.get('NUM_GPUS', 4)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) @@ -62,7 +69,7 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 0)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 20)) ADAPTER_NAME = 'default' @@ -78,19 +85,57 @@ def create_dataset(): return dataset +# ── Utility ─────────────────────────────────────────────────────────────────── + +def convert_topk_prompt_logprobs( + topk_prompt_logprobs_batch: List[List[Optional[List[tuple]]]], +) -> dict: + """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format.""" + batch_logprobs = [] + batch_indices = [] + + for seq_topk in topk_prompt_logprobs_batch: + seq_logprobs = [] + seq_indices = [] + for pos_topk in seq_topk: + if pos_topk is None: + continue + else: + seq_logprobs.append([lp for _, lp in pos_topk]) + seq_indices.append([tid for tid, _ in pos_topk]) + batch_logprobs.append(seq_logprobs) + batch_indices.append(seq_indices) + + # Pad to same seq_len within batch + max_len = max(len(seq) for seq in batch_logprobs) + topk = len(batch_logprobs[0][0]) if batch_logprobs and batch_logprobs[0] else GKD_TOPK + + for i in range(len(batch_logprobs)): + pad_len = max_len - len(batch_logprobs[i]) + if pad_len > 0: + batch_logprobs[i].extend([[0.0] * topk] * pad_len) + batch_indices[i].extend([[0] * topk] * pad_len) + + return { + 'teacher_topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32), + 'teacher_topk_indices': torch.tensor(batch_indices, dtype=torch.long), + } + + # ── Training ────────────────────────────────────────────────────────────────── def main(): 
device_groups = [ - DeviceGroup(name='model', ranks=list(range(NUM_GPUS)), device_type='cuda'), + DeviceGroup(name='student_model', ranks=MODEL_GPUS, device_type='cuda'), + DeviceGroup(name='teacher_sampler', ranks=SAMPLER_GPUS, device_type='cuda'), ] - model_mesh = DeviceMesh.from_sizes(world_size=NUM_GPUS, dp_size=NUM_GPUS) + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) twinkle.initialize( mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, - lazy_collect=False, ) logger.info(get_device_placement()) @@ -98,7 +143,7 @@ def main(): student_model = TransformersModel( model_id=STUDENT_MODEL_ID, device_mesh=model_mesh, - remote_group='model', + remote_group='student_model', ) student_model.add_adapter_to_model( ADAPTER_NAME, @@ -110,13 +155,14 @@ def main(): student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) student_model.set_template('Template', model_id=STUDENT_MODEL_ID) - # ── Teacher model (frozen, for logits) ───────────────────────────────────── - teacher_model = TransformersModel( + # ── Teacher vLLM sampler (for prompt logprobs) ───────────────────────────── + teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - device_mesh=model_mesh, - remote_group='model', + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs'}, + device_mesh=sampler_mesh, + remote_group='teacher_sampler', ) - teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID) + teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID) # ── DataLoader (full-text: prompt + reference answer) ────────────────────── dataloader = DataLoader( @@ -124,11 +170,9 @@ def main(): batch_size=BATCH_SIZE, min_batch_size=BATCH_SIZE, device_mesh=model_mesh, - remote_group='model', + remote_group='student_model', ) - topk = GKD_TOPK if GKD_TOPK > 0 else None - logger.info(f'GKD Off-Policy | 
student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') @@ -139,12 +183,23 @@ def main(): input_data = batch if isinstance(batch, list) else [batch] - # Teacher logits (frozen) - teacher_output = teacher_model.forward_only(inputs=input_data) - teacher_logits = teacher_output.get('logits') + # Remove input_ids so teacher re-encodes with its own tokenizer + for data in input_data: + data.pop('input_ids', None) + + # Teacher vLLM computes top-k prompt logprobs on the reference sequences + teacher_response = teacher_sampler.sample( + input_data, + SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), + ) + + # Convert teacher logprobs to tensor format for GKDLoss + teacher_output = convert_topk_prompt_logprobs( + [resp.topk_prompt_logprobs for resp in teacher_response], + ) # Student forward + GKD backward - student_model.forward_backward(inputs=input_data, teacher_logits=teacher_logits, topk=topk) + student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() optim_step += 1 From 29dc7ac7ad26eaf018d8226239d300e64a8a3aa7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 16 Mar 2026 15:42:17 +0800 Subject: [PATCH 24/56] wip --- cookbook/rl/gkd_off_policy.py | 19 +++++++++---------- cookbook/rl/gkd_on_policy.py | 8 +++----- .../sampler/vllm_sampler/vllm_engine.py | 18 ++++++++++-------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 047f6c69..56d6b766 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -70,6 +70,7 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) GKD_TOPK = int(os.environ.get('GKD_TOPK', 20)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024)) ADAPTER_NAME = 'default' @@ -81,7 +82,6 @@ def create_dataset(): dataset = 
Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) dataset.map(GSM8KFullProcessor()) - dataset.encode() return dataset @@ -181,21 +181,20 @@ def main(): if optim_step >= MAX_STEPS: break - input_data = batch if isinstance(batch, list) else [batch] - - # Remove input_ids so teacher re-encodes with its own tokenizer - for data in input_data: - data.pop('input_ids', None) - # Teacher vLLM computes top-k prompt logprobs on the reference sequences teacher_response = teacher_sampler.sample( - input_data, - SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), + batch, + SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, prompt_logprobs=1, logprobs=GKD_TOPK), ) + input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences] + input_ids_list = [] + for data in input_data: + input_ids_list.append(data.pop('input_ids', None)) # Convert teacher logprobs to tensor format for GKDLoss teacher_output = convert_topk_prompt_logprobs( - [resp.topk_prompt_logprobs for resp in teacher_response], + [resp.prompt_logprobs for resp in teacher_response], + [[sequence.logprobs for sequence in resp.sequences] for resp in teacher_response], ) # Student forward + GKD backward diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 8d234c24..73e0448b 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -67,6 +67,7 @@ BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) LEARNING_RATE = float(os.environ.get('LR', 1e-4)) +N_SAMPLES = int(os.environ.get('N_SAMPLES', 8)) GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) @@ -82,7 +83,6 @@ def create_dataset(): dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) 
dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) dataset.map(GSM8KProcessor()) - dataset.encode(add_generation_prompt=True) return dataset @@ -204,18 +204,16 @@ def main(): break # 1. Student vLLM generates completions - sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=1)) + sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES)) input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - input_ids_list = [] for data in input_data: - input_ids_list.append(data.pop('input_ids', None)) + data.pop('input_ids', None) # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( input_data, SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), ) - teacher_input_ids = teacher_response[0].sequences[0].new_input_feature['input_ids'] # 3. 
Convert teacher logprobs to tensor format for GKDLoss # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 66f01fe8..ffc1398b 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -217,6 +217,7 @@ async def sample( if isinstance(sampling_params, dict): sampling_params = SamplingParams.from_dict(sampling_params) prompt_logprobs_k = sampling_params.prompt_logprobs or 0 + logprobs = sampling_params.logprobs or 0 vllm_params = sampling_params.to_vllm(**kwargs) # Build request @@ -276,8 +277,11 @@ async def sample( if output.logprobs is not None: seq_logprobs = [] for i, lp in enumerate(output.logprobs): - if i < len(token_ids) and token_ids[i] in lp: - seq_logprobs.append(lp[token_ids[i]].logprob) + if i < len(token_ids): + sorted_items = sorted( + lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] + seq_logprobs.append([(tid, lp_obj.logprob) + for tid, lp_obj in sorted_items]) # Map finish_reason to StopReason stop_reason: StopReason = 'length' @@ -299,8 +303,7 @@ async def sample( for i, lp_dict in enumerate(result.prompt_logprobs): if lp_dict is None: - result_prompt_logprobs.append(None) - result_topk_prompt_logprobs.append(None) + assert i == 0, 'Postion > 0 should not has None lobprobs!' 
continue # Get logprob for the actual token @@ -308,7 +311,7 @@ async def sample( token_id = prompt_token_ids[i] if token_id in lp_dict: lp_obj = lp_dict[token_id] - result_prompt_logprobs.append(lp_obj.logprob if hasattr(lp_obj, 'logprob') else lp_obj) + result_prompt_logprobs.append(lp_obj.logprob) else: result_prompt_logprobs.append(None) else: @@ -316,9 +319,8 @@ async def sample( # Get top-k logprobs sorted_items = sorted( - lp_dict.items(), key=lambda x: -(x[1].logprob - if hasattr(x[1], 'logprob') else x[1]))[:prompt_logprobs_k] - result_topk_prompt_logprobs.append([(tid, lp_obj.logprob if hasattr(lp_obj, 'logprob') else lp_obj) + lp_dict.items(), key=lambda x: -(x[1].logprob))[:prompt_logprobs_k] + result_topk_prompt_logprobs.append([(tid, lp_obj.logprob) for tid, lp_obj in sorted_items]) return SampleResponse( From 36a0eb21e94cdedbfaad93a17f062d4c98e5b35b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 11:51:23 +0800 Subject: [PATCH 25/56] wip --- cookbook/rl/gkd_off_policy.py | 90 ++++++++++++------- cookbook/rl/gkd_on_policy.py | 12 ++- cookbook/transformers/fsdp2.py | 2 +- src/twinkle/preprocessor/__init__.py | 2 +- src/twinkle/preprocessor/llm.py | 23 ----- .../sampler/vllm_sampler/vllm_engine.py | 4 +- .../sampler/vllm_sampler/vllm_sampler.py | 4 +- 7 files changed, 70 insertions(+), 67 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 56d6b766..89a0c68d 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -49,29 +49,29 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import GKDLoss from twinkle.model import TransformersModel -from twinkle.preprocessor import GSM8KFullProcessor +from twinkle.preprocessor import GSM8KProcessor from twinkle.sampler import vLLMSampler from twinkle.template import Template logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = 
os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-2B') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8)) NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) -LEARNING_RATE = float(os.environ.get('LR', 1e-4)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) +LEARNING_RATE = float(os.environ.get('LR', 5e-5)) GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) GKD_TOPK = int(os.environ.get('GKD_TOPK', 20)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024)) - +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) +N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) ADAPTER_NAME = 'default' @@ -81,33 +81,63 @@ def create_dataset(): """Full-text dataset with prompt + reference answer for off-policy distillation.""" dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) - dataset.map(GSM8KFullProcessor()) + dataset.map(GSM8KProcessor()) return dataset # ── Utility ─────────────────────────────────────────────────────────────────── def convert_topk_prompt_logprobs( - topk_prompt_logprobs_batch: List[List[Optional[List[tuple]]]], + prompt_logprobs_batch: List[Optional[List[List[tuple]]]], + sequences_logprobs_batch: List[List[Optional[List[List[tuple]]]]], ) -> dict: - """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format.""" + """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output 
format. + + Args: + prompt_logprobs_batch: [batch] each is topk_prompt_logprobs for one request. + Shape: [prompt_seq_len, topk] per request. + sequences_logprobs_batch: [batch][n_samples] each is generated logprobs. + Shape: [generated_len, topk] per sequence. + + Returns: + Dict with expanded teacher logprobs/indices tensors. + Each prompt is expanded N times (one per generated sequence). + """ batch_logprobs = [] batch_indices = [] - for seq_topk in topk_prompt_logprobs_batch: - seq_logprobs = [] - seq_indices = [] - for pos_topk in seq_topk: - if pos_topk is None: - continue - else: - seq_logprobs.append([lp for _, lp in pos_topk]) - seq_indices.append([tid for tid, _ in pos_topk]) - batch_logprobs.append(seq_logprobs) - batch_indices.append(seq_indices) + for prompt_logprobs, sequences_logprobs in zip(prompt_logprobs_batch, sequences_logprobs_batch): + n_samples = len(sequences_logprobs) + + # Parse prompt logprobs (shared across all sequences) + # prompt_logprobs is List[float], expand to [seq_len, topk] with padding + prompt_lps = [] + prompt_ids = [] + if prompt_logprobs is not None: + for lp in prompt_logprobs: + if lp is None: + lp = -1 + # Expand single logprob to topk slots: [lp, 0, 0, ...] 
+ prompt_lps.append([lp] + [0.0] * (GKD_TOPK - 1)) + prompt_ids.append([0] * GKD_TOPK) + + # Expand prompt and concat with each sequence's generated logprobs + for seq_logprobs in sequences_logprobs: + # Start with prompt logprobs (copy for each sequence) + seq_lps = list(prompt_lps) + seq_ids = list(prompt_ids) + + # Append generated token logprobs + if seq_logprobs is not None: + for pos_topk in seq_logprobs: + seq_lps.append([lp for _, lp in pos_topk]) + seq_ids.append([tid for tid, _ in pos_topk]) + + batch_logprobs.append(seq_lps) + batch_indices.append(seq_ids) # Pad to same seq_len within batch - max_len = max(len(seq) for seq in batch_logprobs) + max_len = max(len(seq) for seq in batch_logprobs) if batch_logprobs else 1 topk = len(batch_logprobs[0][0]) if batch_logprobs and batch_logprobs[0] else GKD_TOPK for i in range(len(batch_logprobs)): @@ -116,9 +146,10 @@ def convert_topk_prompt_logprobs( batch_logprobs[i].extend([[0.0] * topk] * pad_len) batch_indices[i].extend([[0] * topk] * pad_len) + # In vllm output, the first position is None, we returns an invalid value(-10000), so roll it to match the labels return { - 'teacher_topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32), - 'teacher_topk_indices': torch.tensor(batch_indices, dtype=torch.long), + 'teacher_topk_logprobs': torch.roll(torch.tensor(batch_logprobs, dtype=torch.float32), shifts=-1, dims=1), + 'teacher_topk_indices': torch.roll(torch.tensor(batch_indices, dtype=torch.long), shifts=-1, dims=1), } @@ -158,7 +189,7 @@ def main(): # ── Teacher vLLM sampler (for prompt logprobs) ───────────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs'}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs'}, device_mesh=sampler_mesh, remote_group='teacher_sampler', ) @@ -184,12 +215,9 @@ def main(): # Teacher vLLM 
computes top-k prompt logprobs on the reference sequences teacher_response = teacher_sampler.sample( batch, - SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, prompt_logprobs=1, logprobs=GKD_TOPK), + SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, prompt_logprobs=1, logprobs=GKD_TOPK, num_samples=N_SAMPLES), ) input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences] - input_ids_list = [] - for data in input_data: - input_ids_list.append(data.pop('input_ids', None)) # Convert teacher logprobs to tensor format for GKDLoss teacher_output = convert_topk_prompt_logprobs( diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 73e0448b..e0788ddf 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -90,7 +90,6 @@ def create_dataset(): def convert_topk_prompt_logprobs( topk_prompt_logprobs_batch: List[List[Optional[List[tuple]]]], - device: str = 'cpu', ) -> dict: """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format. 
@@ -111,7 +110,8 @@ def convert_topk_prompt_logprobs( seq_indices = [] for pos_topk in seq_topk: if pos_topk is None: - continue + seq_logprobs.append([0.0] * len(seq_topk[1]) if len(seq_topk) > 1 and seq_topk[1] else [0.0]) + seq_indices.append([0] * len(seq_topk[1]) if len(seq_topk) > 1 and seq_topk[1] else [0]) else: seq_logprobs.append([lp for _, lp in pos_topk]) seq_indices.append([tid for tid, _ in pos_topk]) @@ -120,7 +120,7 @@ def convert_topk_prompt_logprobs( # Pad to same seq_len within batch max_len = max(len(seq) for seq in batch_logprobs) - topk = len(batch_logprobs[0][0]) if batch_logprobs and batch_logprobs[0] else GKD_TOPK + topk = GKD_TOPK for i in range(len(batch_logprobs)): pad_len = max_len - len(batch_logprobs[i]) @@ -129,8 +129,8 @@ def convert_topk_prompt_logprobs( batch_indices[i].extend([[0] * topk] * pad_len) return { - 'teacher_topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32, device=device), - 'teacher_topk_indices': torch.tensor(batch_indices, dtype=torch.long, device=device), + 'teacher_topk_logprobs': torch.roll(torch.tensor(batch_logprobs, dtype=torch.float32), shifts=-1, dims=1), + 'teacher_topk_indices': torch.roll(torch.tensor(batch_indices, dtype=torch.long), shifts=-1, dims=1), } @@ -206,8 +206,6 @@ def main(): # 1. Student vLLM generates completions sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES)) input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - for data in input_data: - data.pop('input_ids', None) # 2. 
Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index e0f67537..9df3d5cf 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor # Construct a device_mesh, fsdp=4, dp=2 -device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) +device_mesh = DeviceMesh.from_sizes(dp_size=2) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 7234a60a..13b52d99 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, - GSM8KFullProcessor, GSM8KProcessor, SelfCognitionProcessor) + GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 565f1661..3a1c3f11 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -154,26 +154,3 @@ def preprocess(self, row) -> Trajectory: user_data=[('ground_truth', ground_truth)], ) - -class GSM8KFullProcessor(GSM8KProcessor): - """GSM8K preprocessor that includes the reference answer as the assistant message. - - Produces a full Trajectory (prompt + reference answer) suitable for - off-policy knowledge distillation: the student and teacher both see the - ground-truth response text, and labels cover the response tokens. 
- """ - - def preprocess(self, row) -> Trajectory: - question = row['question'] - answer = row.get('answer', '') - ground_truth = self.extract_ground_truth(answer) - - messages = [ - Message(role='system', content=self.system_prompt), - Message(role='user', content=question), - Message(role='assistant', content=answer), - ] - return Trajectory( - messages=messages, - user_data=[('ground_truth', ground_truth)], - ) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index ffc1398b..076d5f48 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -303,7 +303,8 @@ async def sample( for i, lp_dict in enumerate(result.prompt_logprobs): if lp_dict is None: - assert i == 0, 'Postion > 0 should not has None lobprobs!' + result_prompt_logprobs.append(None) + result_topk_prompt_logprobs.append(None) continue # Get logprob for the actual token @@ -322,7 +323,6 @@ async def sample( lp_dict.items(), key=lambda x: -(x[1].logprob))[:prompt_logprobs_k] result_topk_prompt_logprobs.append([(tid, lp_obj.logprob) for tid, lp_obj in sorted_items]) - return SampleResponse( sequences=sequences, prompt_logprobs=result_prompt_logprobs, diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index e33e178a..10e64b39 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -184,14 +184,14 @@ def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = if hasattr(input_ids, 'tolist'): input_ids = input_ids.tolist() - result = InputFeature(input_ids=input_ids) + result = trajectory + result['input_ids'] = input_ids # Attach preprocessed images/videos for vLLM if images: result['images'] = images if videos: result['videos'] = videos - return result async def _sample_single( From 488ea433841ea581bc505dc35b278ee3fe669efb Mon Sep 17 00:00:00 
2001 From: tastelikefeet Date: Tue, 17 Mar 2026 14:51:35 +0800 Subject: [PATCH 26/56] wip --- src/twinkle/utils/parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py index 6dde4871..77504bd7 100644 --- a/src/twinkle/utils/parallel.py +++ b/src/twinkle/utils/parallel.py @@ -18,7 +18,10 @@ def _sanitize_lock_name(name: str) -> str: def acquire_lock(lock: FileLock, blocking: bool): try: - lock.acquire(blocking=blocking) + if 'blocking' in inspect.signature(AsyncEngineArgs).parameters: + lock.acquire(blocking=blocking) + else: + lock.acquire(timeout=(0 if not blocking else None)) return True except TimeoutError: return False From e4b931a43062564e4ee2937fcf78a43cd6b89541 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 15:05:08 +0800 Subject: [PATCH 27/56] wip --- src/twinkle/utils/parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py index 77504bd7..3f0287e8 100644 --- a/src/twinkle/utils/parallel.py +++ b/src/twinkle/utils/parallel.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import os import re +import inspect from contextlib import contextmanager from datasets.utils.filelock import FileLock @@ -18,7 +19,7 @@ def _sanitize_lock_name(name: str) -> str: def acquire_lock(lock: FileLock, blocking: bool): try: - if 'blocking' in inspect.signature(AsyncEngineArgs).parameters: + if 'blocking' in inspect.signature(lock.acquire).parameters: lock.acquire(blocking=blocking) else: lock.acquire(timeout=(0 if not blocking else None)) From 17329d320d1bff7a433131932692cf147f2dd838 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 15:48:32 +0800 Subject: [PATCH 28/56] wip --- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 2 +- src/twinkle/sampler/vllm_sampler/vllm_sampler.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 076d5f48..b7991bee 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -53,7 +53,7 @@ def __init__( gpu_memory_utilization: float = 0.7, max_model_len: Optional[int] = None, max_num_seqs: int = 256, - enable_lora: bool = True, + enable_lora: bool = False, max_loras: int = 1, max_lora_rank: int = 32, enable_sleep_mode: bool = False, diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 10e64b39..583d1b10 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -170,17 +170,14 @@ def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = new_content.append({'type': 'text', 'text': part}) msg['content'] = new_content if new_content else [{'type': 'text', 'text': ''}] - encoded = template.processor.apply_chat_template( - messages, - tokenize=True, - return_dict=True, + encoded = template.batch_encode( + [Trajectory(messages=messages)], add_generation_prompt=True, - 
return_tensors='pt', - ) + )[0] input_ids = encoded['input_ids'] if hasattr(input_ids, 'squeeze'): - input_ids = input_ids.squeeze(0) + input_ids = input_ids.squeeze() if hasattr(input_ids, 'tolist'): input_ids = input_ids.tolist() From fcb163be638a2134af5e27533ac8959ee1bc620f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 16:07:17 +0800 Subject: [PATCH 29/56] wip --- cookbook/transformers/fsdp2.py | 2 +- cookbook/transformers/fsdp2.sh | 2 +- src/twinkle/loss/cross_entropy.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 4d82451e..10d75df6 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -50,7 +50,7 @@ def train(): # Add a lora to model, with name `default` # Comment this to use full-parameter training - # model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) # Add LRScheduler for lora `default` diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh index 93c531a9..46e9f27f 100644 --- a/cookbook/transformers/fsdp2.sh +++ b/cookbook/transformers/fsdp2.sh @@ -1 +1 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2.py +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py index 06bf791a..abcc9591 100644 --- a/src/twinkle/loss/cross_entropy.py +++ b/src/twinkle/loss/cross_entropy.py @@ -1,5 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
-from ..data_format import LossOutput +from twinkle.data_format import LossOutput from .base import Loss From b6332d944a0acd2247a9aefdbc6c761c1c3432d2 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 16:17:27 +0800 Subject: [PATCH 30/56] wip --- src/twinkle/infra/collectors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 3f8c7cf9..4411220d 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -1,4 +1,5 @@ from typing import List, Dict, Any, TYPE_CHECKING +import numpy as np if TYPE_CHECKING: import torch @@ -40,6 +41,9 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An elif isinstance(first_value, dict): result[key] = collect_tensor_dict(values) + elif isinstance(first_value, np.ndarray): + raise NotImplementedError(f'Numpy array not supported for now.') + else: result[key] = values From 37d38c191ff4571ec69e187c568e3785dedafe2f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 17:48:11 +0800 Subject: [PATCH 31/56] lint code --- cookbook/rl/gkd_on_policy.py | 2 +- src/twinkle/infra/collectors.py | 15 ++++---- src/twinkle/loss/__init__.py | 1 - src/twinkle/loss/gkd.py | 11 ++---- src/twinkle/model/megatron/megatron.py | 7 ++-- .../model/megatron/multi_lora_megatron.py | 3 +- .../model/transformers/transformers.py | 3 +- src/twinkle/preprocessor/llm.py | 1 - .../sampler/vllm_sampler/vllm_engine.py | 34 ++++++++----------- .../sampler/vllm_sampler/vllm_sampler.py | 21 +++++++----- src/twinkle/template/base.py | 27 +++++++-------- src/twinkle/utils/parallel.py | 2 +- src/twinkle/utils/torch_utils.py | 2 +- 13 files changed, 61 insertions(+), 68 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index e0788ddf..e2eaebcb 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -206,7 +206,7 @@ def main(): # 1. 
Student vLLM generates completions sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES)) input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - + # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( input_data, diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 4411220d..82c5fb3f 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -1,11 +1,14 @@ -from typing import List, Dict, Any, TYPE_CHECKING import numpy as np +from typing import TYPE_CHECKING, Any, Dict, List + +from twinkle import DeviceMesh if TYPE_CHECKING: import torch -def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: +def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) -> Dict[str, Any]: + import torch if not outputs: return {} @@ -16,7 +19,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An for d in outputs: all_keys.update(d.keys()) - import torch + outputs = [r for i, r in enumerate(outputs) if i in device_mesh.get_pp_last_ranks()] result = {} for key in all_keys: values = [d[key] for d in outputs if key in d] @@ -42,7 +45,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An result[key] = collect_tensor_dict(values) elif isinstance(first_value, np.ndarray): - raise NotImplementedError(f'Numpy array not supported for now.') + raise NotImplementedError('Numpy array not supported for now.') else: result[key] = values @@ -53,7 +56,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, An def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = 0) -> 'torch.Tensor': import torch if not tensors: - raise ValueError("Empty tensor list") + raise ValueError('Empty tensor list') if 
len(tensors) == 1: return tensors[0].unsqueeze(0) @@ -80,4 +83,4 @@ def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = 0) padded = torch.nn.functional.pad(t, pad_params, value=pad_value) padded_tensors.append(padded) - return torch.cat(padded_tensors, dim=0) \ No newline at end of file + return torch.cat(padded_tensors, dim=0) diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 65303dfd..7870f5a4 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -5,7 +5,6 @@ from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss -from .cross_entropy import CrossEntropyLoss torch_loss_mapping = { 'mse': MSELoss, diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index a46e4ffd..3de9179b 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -72,11 +72,8 @@ def __call__( Returns: LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens. """ - assert teacher_logits is not None or ( - teacher_topk_logprobs is not None and teacher_topk_indices is not None - ), ( - 'Either logits or both topk_logprobs and topk_indices must be provided.' 
- ) + assert teacher_logits is not None or (teacher_topk_logprobs is not None and teacher_topk_indices is not None), ( + 'Either logits or both topk_logprobs and topk_indices must be provided.') labels = inputs['labels'] student_logits = outputs['logits'] @@ -89,8 +86,6 @@ def __call__( teacher_topk_logprobs = teacher_topk_logprobs[:, :student_logits.shape[1]] teacher_topk_indices = teacher_topk_indices[:, :student_logits.shape[1]] - shifted_labels = labels - loss = self._generalized_jsd_loss( student_logits=student_logits, teacher_logits=teacher_logits, @@ -175,7 +170,7 @@ def _generalized_jsd_loss( elif stu_dim > tea_dim: teacher_logits = F.pad(teacher_logits, (0, stu_dim - tea_dim)) teacher_logits[..., tea_dim:] = student_logits[..., tea_dim:] - student_logits = student_logits[mask] # [num_valid, vocab/topk] + student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() else: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index c796a6e4..522eb067 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -26,7 +26,8 @@ from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.data_format import InputFeature, ModelOutput, Trajectory from twinkle.hub import HubOperation -from twinkle.loss import Loss, CrossEntropyLoss +from twinkle.infra import collect_tensor_dict +from twinkle.loss import CrossEntropyLoss, Loss from twinkle.metric import LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel from twinkle.patch import Patch, apply_patch @@ -339,7 +340,7 @@ def _postprocess_tensor_cp(self, tensor): def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') - @remote_function(dispatch='slice_dp', collect='last_pp') + 
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], @@ -364,7 +365,7 @@ def calculate_loss(self, **kwargs): def backward(self, **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') - @remote_function(dispatch='slice_dp', collect='last_pp', sync=True) + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 9afbc056..4a2e3fdb 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -15,6 +15,7 @@ from twinkle.loss import Loss from twinkle.metric import Metric from twinkle.processor import InputProcessor +from ...infra import collect_tensor_dict from ..multi_lora import MultiLora from .megatron import MegatronModel from .strategy import MegatronStrategy @@ -111,7 +112,7 @@ def _check_adapter_valid(self, adapter_name: str): def _lazy_wrap_model(self): pass - @remote_function(dispatch='slice_dp', collect='last_pp', sync=True) + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Forward pass without gradient computation. 
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 8733b10c..85352519 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -129,8 +129,7 @@ def accumulate_metrics(self, is_training): step=self.cur_step - 1, gradient_accumulation_steps=self.gradient_accumulation_steps, grad_norm=self._last_grad_norm, - loss_reduction=getattr(self.loss_instance, 'reduction', 'mean') - ) + loss_reduction=getattr(self.loss_instance, 'reduction', 'mean')) def calculate_metrics(self, is_training): self.accumulate_metrics(is_training) diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 3a1c3f11..8d08a908 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -153,4 +153,3 @@ def preprocess(self, row) -> Trajectory: messages=messages, user_data=[('ground_truth', ground_truth)], ) - diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index b7991bee..911b11c3 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -176,18 +176,16 @@ async def get_tokenizer(self): # Core Sampling API # ========================================================================= - async def sample( - self, - prompt_token_ids: List[int], - sampling_params: Union[SamplingParams, Dict[str, Any]], - lora_request: Optional[Any] = None, - request_id: Optional[str] = None, - priority: int = 0, - *, - images: Optional[List[Any]] = None, - videos: Optional[List[Any]] = None, - **kwargs - ) -> SampleResponse: + async def sample(self, + prompt_token_ids: List[int], + sampling_params: Union[SamplingParams, Dict[str, Any]], + lora_request: Optional[Any] = None, + request_id: Optional[str] = None, + priority: int = 0, + *, + images: Optional[List[Any]] = None, + videos: Optional[List[Any]] = None, + **kwargs) -> 
SampleResponse: """ Sample completions from the model. @@ -278,10 +276,8 @@ async def sample( seq_logprobs = [] for i, lp in enumerate(output.logprobs): if i < len(token_ids): - sorted_items = sorted( - lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] - seq_logprobs.append([(tid, lp_obj.logprob) - for tid, lp_obj in sorted_items]) + sorted_items = sorted(lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] + seq_logprobs.append([(tid, lp_obj.logprob) for tid, lp_obj in sorted_items]) # Map finish_reason to StopReason stop_reason: StopReason = 'length' @@ -319,10 +315,8 @@ async def sample( result_prompt_logprobs.append(None) # Get top-k logprobs - sorted_items = sorted( - lp_dict.items(), key=lambda x: -(x[1].logprob))[:prompt_logprobs_k] - result_topk_prompt_logprobs.append([(tid, lp_obj.logprob) - for tid, lp_obj in sorted_items]) + sorted_items = sorted(lp_dict.items(), key=lambda x: -(x[1].logprob))[:prompt_logprobs_k] + result_topk_prompt_logprobs.append([(tid, lp_obj.logprob) for tid, lp_obj in sorted_items]) return SampleResponse( sequences=sequences, prompt_logprobs=result_prompt_logprobs, diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 583d1b10..0be89feb 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -229,15 +229,18 @@ async def _sample_single( ) # response.sequences contains num_samples sequences for this prompt - return SampleResponse(sequences=[ - SampledSequence( - stop_reason=seq.stop_reason, - tokens=seq.tokens, - logprobs=seq.logprobs, - decoded=self.template.decode(seq.tokens), - new_input_feature=self.template.concat_input_feature(feat, seq.tokens), - ) for seq in response.sequences - ], prompt_logprobs=response.prompt_logprobs, topk_prompt_logprobs=response.topk_prompt_logprobs) + return SampleResponse( + sequences=[ + SampledSequence( + stop_reason=seq.stop_reason, + tokens=seq.tokens, + 
logprobs=seq.logprobs, + decoded=self.template.decode(seq.tokens), + new_input_feature=self.template.concat_input_feature(feat, seq.tokens), + ) for seq in response.sequences + ], + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs) @remote_function(dispatch='slice_dp', collect='flatten', lazy_collect=False) def sample( diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 6a2055c5..167a459d 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -9,7 +9,7 @@ from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation from twinkle.utils import load_image, to_device -from .utils import transfer_to_standard_message, TokenizeByRound +from .utils import TokenizeByRound, transfer_to_standard_message if TYPE_CHECKING: import torch @@ -52,7 +52,7 @@ def __init__(self, self._test_support_assistant_tokens_mask() self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [ self._add_default_system, # Add a default system field - self._to_standard_reasoning_content, # Convert thinking to standard field + self._to_standard_reasoning_content, # Convert thinking to standard field self._build_mm_messages, # turn to standard mm messages ] self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ @@ -190,16 +190,16 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: result = [] for message in messages: message = message.copy() - if message.get("role") == "assistant": - content = message.get("content", "") - if "reasoning_content" not in message and isinstance(content, str): - if "" in content: - reasoning_content = content.split("")[0].rstrip("\n").split("")[-1].lstrip( - "\n") - new_content = content.split("")[-1].lstrip("\n") + if message.get('role') == 'assistant': + content = message.get('content', '') + if 'reasoning_content' not in message and isinstance(content, str): + if '' in content: + 
reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip( + '\n') + new_content = content.split('')[-1].lstrip('\n') - message["reasoning_content"] = reasoning_content - message["content"] = new_content + message['reasoning_content'] = reasoning_content + message['content'] = new_content result.append(message) @@ -298,9 +298,8 @@ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> assistant_masks = encoded.pop('assistant_masks') labels = np.where(assistant_masks, input_ids, -100) else: - input_ids, labels, encoded = TokenizeByRound.tokenize_with_assistant_labels(self.tokenizer, - self._apply_chat_template, - trajectory) + input_ids, labels, encoded = TokenizeByRound.tokenize_with_assistant_labels( + self.tokenizer, self._apply_chat_template, trajectory) else: assert len(trajectory['messages']) == 1 and trajectory['messages'][0]['role'] == 'user' text = trajectory['messages'][0]['content'] diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py index 3f0287e8..d9d9d9ba 100644 --- a/src/twinkle/utils/parallel.py +++ b/src/twinkle/utils/parallel.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import inspect import os import re -import inspect from contextlib import contextmanager from datasets.utils.filelock import FileLock diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 7097f946..b0d454de 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -1,6 +1,6 @@ import socket from datetime import timedelta -from typing import TYPE_CHECKING, Any, Mapping, Union, List, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Union from .network import is_valid_ipv6_address From a01c524014dc430cdf198565a12334214d7e1018 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 18:01:29 +0800 Subject: [PATCH 32/56] wip --- cookbook/rl/gkd_off_policy.py | 20 ++++++++++---------- src/twinkle/infra/collectors.py | 4 +++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 89a0c68d..5b7460e1 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -48,7 +48,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import GKDLoss -from twinkle.model import TransformersModel +from twinkle.model import MegatronModel from twinkle.preprocessor import GSM8KProcessor from twinkle.sampler import vLLMSampler from twinkle.template import Template @@ -56,8 +56,8 @@ logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-2B') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B') +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8)) @@ -171,7 +171,7 @@ def main(): 
logger.info(get_device_placement()) # ── Student model (trainable) ────────────────────────────────────────────── - student_model = TransformersModel( + student_model = MegatronModel( model_id=STUDENT_MODEL_ID, device_mesh=model_mesh, remote_group='student_model', @@ -181,8 +181,8 @@ def main(): LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), gradient_accumulation_steps=1, ) - student_model.set_optimizer('AdamW', lr=LEARNING_RATE) - student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + student_model.set_optimizer('default', lr=LEARNING_RATE) + student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) student_model.set_template('Template', model_id=STUDENT_MODEL_ID) @@ -228,14 +228,14 @@ def main(): # Student forward + GKD backward student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() - optim_step += 1 - if optim_step % 10 == 0: - metric = student_model.calculate_metric(is_training=True) - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') + metric = student_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') if optim_step % 50 == 0: student_model.save(f'gkd-offpolicy-ckpt-{optim_step}') + + optim_step += 1 student_model.save('gkd-offpolicy-final') logger.info('GKD off-policy training completed.') diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 82c5fb3f..0e2b4f26 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -44,12 +44,14 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) elif isinstance(first_value, dict): result[key] = collect_tensor_dict(values) - elif isinstance(first_value, np.ndarray): + elif isinstance(first_value, np.ndarray) and first_value.size > 1: raise NotImplementedError('Numpy array not supported 
for now.') else: result[key] = values + if 'loss' in result and len(result['loss']) > 1: + result['loss'] = np.mean(result['loss']) return result From 45c09a183f7a6b291bf6f04d9776e9e089e2d9c7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 19:56:35 +0800 Subject: [PATCH 33/56] wip --- cookbook/megatron/tp.py | 8 ++++---- cookbook/rl/gkd_off_policy.py | 10 +++++----- src/twinkle/metric/train_metric.py | 6 ++++-- src/twinkle/model/megatron/megatron.py | 4 ++++ 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index ee457fe7..8cbd7b0b 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -19,7 +19,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-0.6B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=16) @@ -33,7 +33,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-0.6B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -41,7 +41,7 @@ def train(): # Global batch size = 1, dp_size = 1 dataloader = DataLoader(dataset=dataset, batch_size=16) # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B') + model = MegatronModel(model_id='ms://Qwen/Qwen3-0.6B') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') @@ -51,7 +51,7 @@ def train(): # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='default', 
lr=1e-4) # Add LRScheduler for lora `default` - model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=0, lr_decay_steps=len(dataloader)) logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 5b7460e1..af2d9c48 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -63,7 +63,7 @@ SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8)) NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-5)) @@ -179,7 +179,6 @@ def main(): student_model.add_adapter_to_model( ADAPTER_NAME, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), - gradient_accumulation_steps=1, ) student_model.set_optimizer('default', lr=LEARNING_RATE) student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) @@ -229,10 +228,11 @@ def main(): student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() - metric = student_model.calculate_metric(is_training=True) - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') + if optim_step > 0: + metric = student_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') - if optim_step % 50 == 0: + if optim_step % 50 == 0 and optim_step > 0: student_model.save(f'gkd-offpolicy-ckpt-{optim_step}') optim_step += 1 diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py index 201ff859..da82a878 100644 --- a/src/twinkle/metric/train_metric.py +++ b/src/twinkle/metric/train_metric.py @@ -22,14 +22,16 @@ def __init__(self, device_mesh=None, process_group=None, **kwargs): 
self.gradient_accumulation_steps = 1 self.start_time = time.time() self.time = time.time() + self.lrs = [] def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): lr = kwargs.get('lr') if isinstance(lr, list): - lr = [f'{x:.2e}' for x in lr] + lr = [f'{x:.6e}' for x in lr] else: - lr = f'{lr:.2e}' + lr = f'{lr:.6e}' self.lr = lr + self.lrs.append(lr) self.step = kwargs.get('step') self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', self.gradient_accumulation_steps) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 522eb067..bf78cf99 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -364,6 +364,10 @@ def calculate_loss(self, **kwargs): @remote_function() def backward(self, **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') + + @remote_function(collect='first', lazy_collect=False) + def get_lr(self): + return self.optimizer_group['default']._get_lr() @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_backward(self, From 1c12fff2687ee441041602c30ee64648b95bcb5e Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 23:27:26 +0800 Subject: [PATCH 34/56] wip --- cookbook/rl/gkd_off_policy.py | 5 ++--- src/twinkle/preprocessor/llm.py | 5 +---- src/twinkle/utils/torch_utils.py | 10 ---------- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index af2d9c48..d27ecb15 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -69,7 +69,7 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 20)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) 
N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) ADAPTER_NAME = 'default' @@ -188,7 +188,7 @@ def main(): # ── Teacher vLLM sampler (for prompt logprobs) ───────────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs'}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64}, device_mesh=sampler_mesh, remote_group='teacher_sampler', ) @@ -199,7 +199,6 @@ def main(): dataset=create_dataset, batch_size=BATCH_SIZE, min_batch_size=BATCH_SIZE, - device_mesh=model_mesh, remote_group='student_model', ) diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 8d08a908..acaea41c 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -122,10 +122,7 @@ class GSM8KProcessor(Preprocessor): Extracts the ground truth number and stores it in user_data for reward. """ - system_prompt = ('You are a helpful math assistant. Solve the problem step by step. ' - 'Show your reasoning in tags, then give the final ' - 'numerical answer after ####.\n' - 'For example:\n ... reasoning ... \n#### 42') + system_prompt = ('You are a helpful math assistant. 
Solve the problem step by step and put your final answer within \\boxed{}.') def extract_ground_truth(self, answer_str: str) -> str: """Extract the number after '####' from GSM8K answer.""" diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index b0d454de..f13eb056 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -72,16 +72,6 @@ def selective_log_softmax(logits, index) -> 'torch.Tensor': import torch import torch.nn.functional as F - try: - from megatron.core import parallel_state as mpu - if mpu.get_tensor_model_parallel_world_size() >= 1: - try: - return _vocab_parallel_selective_log_softmax(logits, index) - except Exception: # noqa - import traceback - print(traceback.format_exc()) - except Exception: # noqa - pass if logits.dtype in [torch.float32, torch.float64]: selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption From f2a1fc73291c5050adbfd8b871e8afc626d31ca2 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 17 Mar 2026 23:46:14 +0800 Subject: [PATCH 35/56] fix --- cookbook/rl/gkd_off_policy.py | 78 +++++++++++----------------- src/twinkle/preprocessor/__init__.py | 2 +- src/twinkle/preprocessor/llm.py | 26 +++++++++- 3 files changed, 55 insertions(+), 51 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index d27ecb15..6aa0b544 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -49,7 +49,7 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import GKDLoss from twinkle.model import MegatronModel -from twinkle.preprocessor import GSM8KProcessor +from twinkle.preprocessor import GSM8KFullProcessor from twinkle.sampler import vLLMSampler from twinkle.template import Template @@ -70,8 +70,6 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) GKD_TOPK = 
int(os.environ.get('GKD_TOPK', 64)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) -N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) ADAPTER_NAME = 'default' @@ -81,72 +79,52 @@ def create_dataset(): """Full-text dataset with prompt + reference answer for off-policy distillation.""" dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) - dataset.map(GSM8KProcessor()) + dataset.map(GSM8KFullProcessor()) return dataset # ── Utility ─────────────────────────────────────────────────────────────────── def convert_topk_prompt_logprobs( - prompt_logprobs_batch: List[Optional[List[List[tuple]]]], - sequences_logprobs_batch: List[List[Optional[List[List[tuple]]]]], + topk_prompt_logprobs_batch: List[Optional[List[List[tuple]]]], ) -> dict: """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format. Args: - prompt_logprobs_batch: [batch] each is topk_prompt_logprobs for one request. - Shape: [prompt_seq_len, topk] per request. - sequences_logprobs_batch: [batch][n_samples] each is generated logprobs. - Shape: [generated_len, topk] per sequence. + topk_prompt_logprobs_batch: [batch] each is topk_prompt_logprobs for one request. + Shape: [seq_len, topk] per request, where each position is List[(token_id, logprob)]. Returns: - Dict with expanded teacher logprobs/indices tensors. - Each prompt is expanded N times (one per generated sequence). + Dict with teacher logprobs/indices tensors. 
""" batch_logprobs = [] batch_indices = [] - for prompt_logprobs, sequences_logprobs in zip(prompt_logprobs_batch, sequences_logprobs_batch): - n_samples = len(sequences_logprobs) - - # Parse prompt logprobs (shared across all sequences) - # prompt_logprobs is List[float], expand to [seq_len, topk] with padding - prompt_lps = [] - prompt_ids = [] - if prompt_logprobs is not None: - for lp in prompt_logprobs: - if lp is None: - lp = -1 - # Expand single logprob to topk slots: [lp, 0, 0, ...] - prompt_lps.append([lp] + [0.0] * (GKD_TOPK - 1)) - prompt_ids.append([0] * GKD_TOPK) - - # Expand prompt and concat with each sequence's generated logprobs - for seq_logprobs in sequences_logprobs: - # Start with prompt logprobs (copy for each sequence) - seq_lps = list(prompt_lps) - seq_ids = list(prompt_ids) - - # Append generated token logprobs - if seq_logprobs is not None: - for pos_topk in seq_logprobs: - seq_lps.append([lp for _, lp in pos_topk]) - seq_ids.append([tid for tid, _ in pos_topk]) - - batch_logprobs.append(seq_lps) - batch_indices.append(seq_ids) + for seq_topk in topk_prompt_logprobs_batch: + seq_logprobs = [] + seq_indices = [] + if seq_topk is not None: + for pos_topk in seq_topk: + if pos_topk is None: + # First position is None, fill with placeholder + seq_logprobs.append([0.0] * GKD_TOPK) + seq_indices.append([0] * GKD_TOPK) + else: + seq_logprobs.append([lp for _, lp in pos_topk]) + seq_indices.append([tid for tid, _ in pos_topk]) + batch_logprobs.append(seq_logprobs) + batch_indices.append(seq_indices) # Pad to same seq_len within batch max_len = max(len(seq) for seq in batch_logprobs) if batch_logprobs else 1 - topk = len(batch_logprobs[0][0]) if batch_logprobs and batch_logprobs[0] else GKD_TOPK for i in range(len(batch_logprobs)): pad_len = max_len - len(batch_logprobs[i]) if pad_len > 0: - batch_logprobs[i].extend([[0.0] * topk] * pad_len) - batch_indices[i].extend([[0] * topk] * pad_len) + batch_logprobs[i].extend([[0.0] * GKD_TOPK] * pad_len) + 
batch_indices[i].extend([[0] * GKD_TOPK] * pad_len) - # In vllm output, the first position is None, we returns an invalid value(-10000), so roll it to match the labels + # Roll to align with labels (first position has no valid logprobs) return { 'teacher_topk_logprobs': torch.roll(torch.tensor(batch_logprobs, dtype=torch.float32), shifts=-1, dims=1), 'teacher_topk_indices': torch.roll(torch.tensor(batch_indices, dtype=torch.long), shifts=-1, dims=1), @@ -211,16 +189,18 @@ def main(): break # Teacher vLLM computes top-k prompt logprobs on the reference sequences + # max_tokens=1: don't generate new content, just compute logprobs on input teacher_response = teacher_sampler.sample( batch, - SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, prompt_logprobs=1, logprobs=GKD_TOPK, num_samples=N_SAMPLES), + SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK, num_samples=1), ) - input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences] + + # Use original batch as input_data (dataset reference responses) + input_data = batch if isinstance(batch, list) else [batch] # Convert teacher logprobs to tensor format for GKDLoss teacher_output = convert_topk_prompt_logprobs( - [resp.prompt_logprobs for resp in teacher_response], - [[sequence.logprobs for sequence in resp.sequences] for resp in teacher_response], + [resp.topk_prompt_logprobs for resp in teacher_response], ) # Student forward + GKD backward diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 13b52d99..7234a60a 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import DataFilter, Preprocessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, - GSM8KProcessor, SelfCognitionProcessor) + GSM8KFullProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index acaea41c..9cfc6f6e 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -116,10 +116,11 @@ def preprocess(self, row) -> Trajectory: class GSM8KProcessor(Preprocessor): - """Preprocessor for GSM8K dataset. + """Preprocessor for GSM8K dataset (prompt-only, for on-policy generation). GSM8K fields: question (str), answer (str ending with '#### ') Extracts the ground truth number and stores it in user_data for reward. + Only includes system + user messages; assistant response is generated on-policy. """ system_prompt = ('You are a helpful math assistant. Solve the problem step by step and put your final answer within \\boxed{}.') @@ -150,3 +151,26 @@ def preprocess(self, row) -> Trajectory: messages=messages, user_data=[('ground_truth', ground_truth)], ) + + +class GSM8KFullProcessor(GSM8KProcessor): + """Preprocessor for GSM8K dataset (full trajectory, for off-policy distillation). + + Includes system + user + assistant messages with the reference answer. + Used when training on existing responses (off-policy) rather than generating new ones. 
+ """ + + def preprocess(self, row) -> Trajectory: + question = row['question'] + answer = row.get('answer', '') + ground_truth = self.extract_ground_truth(answer) + + messages = [ + Message(role='system', content=self.system_prompt), + Message(role='user', content=question), + Message(role='assistant', content=answer), + ] + return Trajectory( + messages=messages, + user_data=[('ground_truth', ground_truth)], + ) From 519dba974882dc3e18c916dfda01ce4540838f98 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 00:38:02 +0800 Subject: [PATCH 36/56] wip --- cookbook/rl/gkd_off_policy.py | 7 +-- src/twinkle/preprocessor/llm.py | 2 +- .../sampler/vllm_sampler/vllm_sampler.py | 58 ++++++++++++------- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 6aa0b544..3d234d87 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -189,14 +189,13 @@ def main(): break # Teacher vLLM computes top-k prompt logprobs on the reference sequences - # max_tokens=1: don't generate new content, just compute logprobs on input + # max_tokens=0: don't generate new content, just compute logprobs on input teacher_response = teacher_sampler.sample( batch, - SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK, num_samples=1), + SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK, num_samples=1), ) - # Use original batch as input_data (dataset reference responses) - input_data = batch if isinstance(batch, list) else [batch] + input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences] # Convert teacher logprobs to tensor format for GKDLoss teacher_output = convert_topk_prompt_logprobs( diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 9cfc6f6e..1b763a6a 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -123,7 +123,7 @@ class 
GSM8KProcessor(Preprocessor): Only includes system + user messages; assistant response is generated on-policy. """ - system_prompt = ('You are a helpful math assistant. Solve the problem step by step and put your final answer within \\boxed{}.') + system_prompt = ('You are a helpful math assistant. Solve the problem step by step and put your final answer within #### ') def extract_ground_truth(self, answer_str: str) -> str: """Extract the number after '####' from GSM8K answer.""" diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 0be89feb..04db3ba6 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -125,7 +125,7 @@ async def _create_engine_async(self, engine_cls, model_id, engine_kwargs): """Create engine in async context to ensure output_handler starts correctly.""" return engine_cls(model_id=model_id, **engine_kwargs) - def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = '') -> InputFeature: + def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = '', add_generation_prompt=True) -> InputFeature: """Encode trajectory for vLLM - does not expand image tokens. 
Args: @@ -172,7 +172,7 @@ def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = encoded = template.batch_encode( [Trajectory(messages=messages)], - add_generation_prompt=True, + add_generation_prompt=add_generation_prompt, )[0] input_ids = encoded['input_ids'] @@ -182,7 +182,7 @@ def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = input_ids = input_ids.tolist() result = trajectory - result['input_ids'] = input_ids + result.update(encoded) # Attach preprocessed images/videos for vLLM if images: @@ -197,7 +197,7 @@ async def _sample_single( sampling_params: SamplingParams, lora_request: Optional[Any] = None, *, - logprobs: bool = True, + logprobs_only: bool = False, ) -> List[SampledSequence]: """Sample a single input asynchronously. @@ -207,7 +207,7 @@ async def _sample_single( adapter_path: Optional LoRA adapter path (legacy, prefer lora_request). lora_request: Pre-built LoRARequest to attach to the sampling request. Avoids repeated ``_get_or_load_lora`` calls per input. - num_samples: Number of completions to generate for this prompt. + logprobs_only: Only return logprobs (no generated tokens). Returns: List of num_samples SampledSequence objects. 
@@ -222,25 +222,36 @@ async def _sample_single( response = await self.engine.sample( prompt_token_ids=input_ids, sampling_params=sampling_params, - logprobs=logprobs, lora_request=lora_request, images=images, videos=videos, ) - # response.sequences contains num_samples sequences for this prompt - return SampleResponse( - sequences=[ - SampledSequence( - stop_reason=seq.stop_reason, - tokens=seq.tokens, - logprobs=seq.logprobs, - decoded=self.template.decode(seq.tokens), - new_input_feature=self.template.concat_input_feature(feat, seq.tokens), - ) for seq in response.sequences - ], - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs) + if not logprobs_only: + # response.sequences contains num_samples sequences for this prompt + return SampleResponse( + sequences=[ + SampledSequence( + stop_reason=seq.stop_reason, + tokens=seq.tokens, + logprobs=seq.logprobs, + decoded=self.template.decode(seq.tokens), + new_input_feature=self.template.concat_input_feature(feat, seq.tokens), + ) for seq in response.sequences + ], + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs) + else: + return SampleResponse( + sequences=[ + SampledSequence( + tokens=[], + stop_reason=seq.stop_reason, + new_input_feature=feat, + ) for seq in response.sequences + ], + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs) @remote_function(dispatch='slice_dp', collect='flatten', lazy_collect=False) def sample( @@ -286,12 +297,16 @@ def sample( # Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic is_trajectory = self._is_trajectory(inputs) - + logprobs_only = False + if sampling_params.max_tokens == 0: + sampling_params.max_tokens = 1 + logprobs_only = True + if is_trajectory: template = self.template assert template is not None, \ 'Use set_template to add a template when trying to input Trajectory' - encoded_inputs = 
[self.encode_trajectory_for_vllm(traj, adapter_name) for traj in inputs_list] + encoded_inputs = [self.encode_trajectory_for_vllm(traj, adapter_name, not logprobs_only) for traj in inputs_list] else: encoded_inputs = inputs_list @@ -309,6 +324,7 @@ async def _sample_all(): feat, sampling_params, lora_request=lora_request, + logprobs_only=logprobs_only, ) for feat in encoded_inputs ] return await asyncio.gather(*tasks) From 2576f1852b820f93fa128b9ee8c1de8b73843f27 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 10:57:20 +0800 Subject: [PATCH 37/56] fix --- cookbook/rl/gkd_on_policy.py | 72 +++++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index e2eaebcb..6faea2bd 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -37,6 +37,8 @@ """ import os +import threading +from queue import Queue from typing import List, Optional import torch @@ -198,39 +200,73 @@ def main(): logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') + # ── Async sampling with queue ─────────────────────────────────────────────── + sample_queue: Queue = Queue(maxsize=2) # Prefetch up to 2 batches + stop_event = threading.Event() + + def sample_producer(): + """Background thread: sample from student/teacher and put results in queue.""" + for batch in dataloader: + if stop_event.is_set(): + break + + # 1. Student vLLM generates completions + sample_response = student_sampler.sample( + batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES) + ) + input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] + + # 2. 
Teacher vLLM computes top-k prompt logprobs on generated sequences + teacher_response = teacher_sampler.sample( + input_data, + SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK), + ) + + # 3. Convert teacher logprobs to tensor format for GKDLoss + teacher_output = convert_topk_prompt_logprobs( + [resp.topk_prompt_logprobs for resp in teacher_response], + ) + + # Put (input_data, teacher_output) into queue + sample_queue.put((input_data, teacher_output)) + + # Signal end of data + sample_queue.put(None) + + # Start sampling thread + producer_thread = threading.Thread(target=sample_producer, daemon=True) + producer_thread.start() + + # ── Training loop (consume from queue) ────────────────────────────────────── optim_step = 0 - for batch in dataloader: + while True: if optim_step >= MAX_STEPS: + stop_event.set() break - # 1. Student vLLM generates completions - sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES)) - input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - - # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences - teacher_response = teacher_sampler.sample( - input_data, - SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), - ) + # Get data from queue (blocking) + item = sample_queue.get() + if item is None: # End of data + break - # 3. Convert teacher logprobs to tensor format for GKDLoss - # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each - teacher_output = convert_topk_prompt_logprobs( - [resp.topk_prompt_logprobs for resp in teacher_response], - ) + input_data, teacher_output = item - # 4. 
Student forward + GKD backward + # Student forward + GKD backward student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() optim_step += 1 - if optim_step % 10 == 0: + if optim_step > 0 and optim_step % 10 == 0: metric = student_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') - if optim_step % 50 == 0: + if optim_step > 0 and optim_step % 50 == 0: student_model.save(f'gkd-onpolicy-ckpt-{optim_step}') + # Wait for producer to finish + stop_event.set() + producer_thread.join(timeout=5) + student_model.save('gkd-onpolicy-final') logger.info('GKD on-policy training completed.') From e23ee4184f70845508794f4978782f7c6382a129 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 11:19:31 +0800 Subject: [PATCH 38/56] fix --- .../sampler/vllm_sampler/vllm_sampler.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 04db3ba6..85181a8e 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -25,6 +25,8 @@ import threading from typing import Any, Dict, List, Optional, Union +import numpy as np + from twinkle import DeviceMesh, get_logger, remote_class, remote_function, requires from twinkle.checkpoint_engine import CheckpointEngineMixin from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory @@ -35,6 +37,17 @@ logger = get_logger() +def _convert_ndarray_to_list(obj: Any) -> Any: + """Recursively convert numpy arrays to lists in a dict/list structure.""" + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {k: _convert_ndarray_to_list(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_convert_ndarray_to_list(item) for item in obj] + return obj + + 
@remote_class() class vLLMSampler(Sampler, CheckpointEngineMixin): """A vLLM-based sampler using VLLMEngine (AsyncLLM). @@ -236,7 +249,9 @@ async def _sample_single( tokens=seq.tokens, logprobs=seq.logprobs, decoded=self.template.decode(seq.tokens), - new_input_feature=self.template.concat_input_feature(feat, seq.tokens), + new_input_feature=_convert_ndarray_to_list( + self.template.concat_input_feature(feat, seq.tokens) + ), ) for seq in response.sequences ], prompt_logprobs=response.prompt_logprobs, @@ -247,7 +262,7 @@ async def _sample_single( SampledSequence( tokens=[], stop_reason=seq.stop_reason, - new_input_feature=feat, + new_input_feature=_convert_ndarray_to_list(feat), ) for seq in response.sequences ], prompt_logprobs=response.prompt_logprobs, From 23706997270ab07362762c9874842b4196a8c8be Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 12:33:33 +0800 Subject: [PATCH 39/56] Revert "fix" This reverts commit 2576f1852b820f93fa128b9ee8c1de8b73843f27. --- cookbook/rl/gkd_on_policy.py | 72 +++++++++--------------------------- 1 file changed, 18 insertions(+), 54 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 6faea2bd..e2eaebcb 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -37,8 +37,6 @@ """ import os -import threading -from queue import Queue from typing import List, Optional import torch @@ -200,73 +198,39 @@ def main(): logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') - # ── Async sampling with queue ─────────────────────────────────────────────── - sample_queue: Queue = Queue(maxsize=2) # Prefetch up to 2 batches - stop_event = threading.Event() - - def sample_producer(): - """Background thread: sample from student/teacher and put results in queue.""" - for batch in dataloader: - if stop_event.is_set(): - break - - # 1. 
Student vLLM generates completions - sample_response = student_sampler.sample( - batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES) - ) - input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - - # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences - teacher_response = teacher_sampler.sample( - input_data, - SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK), - ) - - # 3. Convert teacher logprobs to tensor format for GKDLoss - teacher_output = convert_topk_prompt_logprobs( - [resp.topk_prompt_logprobs for resp in teacher_response], - ) - - # Put (input_data, teacher_output) into queue - sample_queue.put((input_data, teacher_output)) - - # Signal end of data - sample_queue.put(None) - - # Start sampling thread - producer_thread = threading.Thread(target=sample_producer, daemon=True) - producer_thread.start() - - # ── Training loop (consume from queue) ────────────────────────────────────── optim_step = 0 - while True: + for batch in dataloader: if optim_step >= MAX_STEPS: - stop_event.set() break - # Get data from queue (blocking) - item = sample_queue.get() - if item is None: # End of data - break + # 1. Student vLLM generates completions + sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES)) + input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - input_data, teacher_output = item + # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences + teacher_response = teacher_sampler.sample( + input_data, + SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), + ) - # Student forward + GKD backward + # 3. 
Convert teacher logprobs to tensor format for GKDLoss + # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each + teacher_output = convert_topk_prompt_logprobs( + [resp.topk_prompt_logprobs for resp in teacher_response], + ) + + # 4. Student forward + GKD backward student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() optim_step += 1 - if optim_step > 0 and optim_step % 10 == 0: + if optim_step % 10 == 0: metric = student_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') - if optim_step > 0 and optim_step % 50 == 0: + if optim_step % 50 == 0: student_model.save(f'gkd-onpolicy-ckpt-{optim_step}') - # Wait for producer to finish - stop_event.set() - producer_thread.join(timeout=5) - student_model.save('gkd-onpolicy-final') logger.info('GKD on-policy training completed.') From 7e83bdcc090928aebee06e972833ea09876a329d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 12:35:52 +0800 Subject: [PATCH 40/56] fix --- cookbook/rl/gkd_on_policy.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index e2eaebcb..fec3c04a 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -44,6 +44,7 @@ import twinkle from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.checkpoint_engine import CheckpointEngineManager from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -171,7 +172,7 @@ def main(): # ── Student vLLM sampler (for on-policy generation) ──────────────────────── student_sampler = vLLMSampler( model_id=STUDENT_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'enable_lora': True, 'max_loras': 1}, 
device_mesh=sampler_mesh, remote_group='student_sampler', ) @@ -195,6 +196,9 @@ def main(): remote_group='student_model', ) + # ── Checkpoint manager for weight sync ────────────────────────────────────── + ckpt_manager = CheckpointEngineManager(model=student_model, sampler=student_sampler) + logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') @@ -203,23 +207,27 @@ def main(): if optim_step >= MAX_STEPS: break - # 1. Student vLLM generates completions + # 1. Sync student model weights to student sampler + ckpt_manager.sync_weights(merge_and_sync=False) + student_sampler.reset_prefix_cache() + + # 2. Student vLLM generates completions sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=N_SAMPLES)) input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] - # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences + # 3. Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( input_data, SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), ) - # 3. Convert teacher logprobs to tensor format for GKDLoss + # 4. Convert teacher logprobs to tensor format for GKDLoss # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each teacher_output = convert_topk_prompt_logprobs( [resp.topk_prompt_logprobs for resp in teacher_response], ) - # 4. Student forward + GKD backward + # 5. 
Student forward + GKD backward student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() optim_step += 1 From ff377897f7c9cb7f6f172286456f6d25781176ea Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 12:37:25 +0800 Subject: [PATCH 41/56] fix --- cookbook/rl/gkd_on_policy.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index fec3c04a..58467af1 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -60,19 +60,19 @@ STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 512)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) -LEARNING_RATE = float(os.environ.get('LR', 1e-4)) -N_SAMPLES = int(os.environ.get('N_SAMPLES', 8)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) +LEARNING_RATE = float(os.environ.get('LR', 5e-5)) +N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) -GKD_TOPK = int(os.environ.get('GKD_TOPK', 20)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) ADAPTER_NAME = 'default' @@ -151,7 +151,6 @@ def main(): nproc_per_node=NUM_GPUS, groups=device_groups, ) - logger.info(get_device_placement()) # ── Student model (trainable) ────────────────────────────────────────────── student_model = TransformersModel( @@ -181,7 +180,7 @@ def main(): # ── Teacher 
vLLM sampler (for prompt logprobs) ─────────────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs'}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64}, device_mesh=sampler_mesh, remote_group='teacher_sampler', ) @@ -199,6 +198,7 @@ def main(): # ── Checkpoint manager for weight sync ────────────────────────────────────── ckpt_manager = CheckpointEngineManager(model=student_model, sampler=student_sampler) + logger.info(get_device_placement()) logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') From a6205ce8ded2ef5a68770d1b7260967252303922 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 14:05:01 +0800 Subject: [PATCH 42/56] fix --- cookbook/rl/gkd_on_policy.py | 10 +++++----- src/twinkle/sampler/vllm_sampler/vllm_sampler.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 58467af1..d45ddda6 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -64,8 +64,8 @@ SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-5)) N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) @@ -171,7 +171,7 @@ def main(): # ── Student vLLM sampler (for on-policy generation) ──────────────────────── student_sampler = vLLMSampler( model_id=STUDENT_MODEL_ID, - engine_args={'gpu_memory_utilization': 
0.85, 'max_model_len': 2048, 'enable_lora': True, 'max_loras': 1}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 4096, 'enable_lora': True, 'max_loras': 1}, device_mesh=sampler_mesh, remote_group='student_sampler', ) @@ -180,7 +180,7 @@ def main(): # ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64}, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 4096, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64}, device_mesh=sampler_mesh, remote_group='teacher_sampler', ) @@ -218,7 +218,7 @@ def main(): # 3. Teacher vLLM computes top-k prompt logprobs on generated sequences teacher_response = teacher_sampler.sample( input_data, - SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=GKD_TOPK), + SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK), ) # 4. 
Convert teacher logprobs to tensor format for GKDLoss diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 85181a8e..ec991ed1 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -38,13 +38,19 @@ def _convert_ndarray_to_list(obj: Any) -> Any: - """Recursively convert numpy arrays to lists in a dict/list structure.""" if isinstance(obj, np.ndarray): return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.bool_): + return bool(obj) elif isinstance(obj, dict): return {k: _convert_ndarray_to_list(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [_convert_ndarray_to_list(item) for item in obj] + elif isinstance(obj, (list, tuple)): + converted = [_convert_ndarray_to_list(item) for item in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted return obj From 781514ebc5b8d9f676356b5915f5d8f2bf24baab Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 18:01:08 +0800 Subject: [PATCH 43/56] wip --- cookbook/rl/gkd_on_policy.py | 6 +++--- src/twinkle/checkpoint_engine/manager.py | 3 --- src/twinkle/infra/collectors.py | 2 +- src/twinkle/model/transformers/transformers.py | 2 +- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index d45ddda6..e73d46b1 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -60,12 +60,12 @@ STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS MAX_NEW_TOKENS = 
int(os.environ.get('MAX_NEW_TOKENS', 2048)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-5)) N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) @@ -228,7 +228,7 @@ def main(): ) # 5. Student forward + GKD backward - student_model.forward_backward(inputs=input_data, **teacher_output) + student_model.forward_backward(inputs=input_data, **teacher_output)() student_model.clip_grad_and_step() optim_step += 1 diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index 29aaaec7..89c2ce0a 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -130,6 +130,3 @@ def sync_weights(self, merge_and_sync=True): if not self.base_sync_done: self.base_sync_done = True logger.info('Base model sync completed, subsequent syncs will be LoRA-only') - - elapsed = time.time() - start_time - logger.info(f'Weight sync completed in {elapsed:.2f}s') diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 0e2b4f26..af4d6d6e 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -55,7 +55,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) return result -def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = 0) -> 'torch.Tensor': +def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor': import torch if not tensors: raise ValueError('Empty tensor list') diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 85352519..93811dfc 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -514,7 +514,7 @@ def backward(self, **kwargs): optimizer_config.cur_step += 1 
optimizer_config.loss_value = None - @remote_function(dispatch='slice_dp', collect='flatten') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Do forward, calculate loss, and backward. From 7653caf886b7d78ae9ccae43795f6b6ad43f4cbb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 18:37:23 +0800 Subject: [PATCH 44/56] lint code --- cookbook/megatron/tp.py | 8 ++-- cookbook/rl/gkd_off_policy.py | 12 +++--- cookbook/rl/gkd_on_policy.py | 14 ++++--- .../DeviceMesh-and-DeviceGroup.md | 1 + .../DeviceMesh\345\222\214DeviceGroup.md" | 1 + src/twinkle/checkpoint_engine/manager.py | 1 - src/twinkle/loss/gkd.py | 4 +- src/twinkle/model/megatron/megatron.py | 2 +- src/twinkle/preprocessor/__init__.py | 2 +- src/twinkle/preprocessor/llm.py | 37 ++++++------------- .../sampler/vllm_sampler/vllm_sampler.py | 19 ++++++---- 11 files changed, 48 insertions(+), 53 deletions(-) diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index 8cbd7b0b..ee457fe7 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -19,7 +19,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-0.6B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=16) @@ -33,7 +33,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-0.6B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format 
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -41,7 +41,7 @@ def train(): # Global batch size = 1, dp_size = 1 dataloader = DataLoader(dataset=dataset, batch_size=16) # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3-0.6B') + model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') @@ -51,7 +51,7 @@ def train(): # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='default', lr=1e-4) # Add LRScheduler for lora `default` - model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=0, lr_decay_steps=len(dataloader)) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 3d234d87..92f82784 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -49,7 +49,7 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import GKDLoss from twinkle.model import MegatronModel -from twinkle.preprocessor import GSM8KFullProcessor +from twinkle.preprocessor import GSM8KProcessor from twinkle.sampler import vLLMSampler from twinkle.template import Template @@ -71,6 +71,8 @@ GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) ADAPTER_NAME = 'default' +SYSTEM_PROMPT = ('You are a helpful math assistant. 
Solve the problem step by step and put ' + 'your final answer within #### ') # ── Dataset ─────────────────────────────────────────────────────────────────── @@ -79,7 +81,7 @@ def create_dataset(): """Full-text dataset with prompt + reference answer for off-policy distillation.""" dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) - dataset.map(GSM8KFullProcessor()) + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT, add_assistant=True)) return dataset @@ -206,13 +208,13 @@ def main(): student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() - if optim_step > 0: + if optim_step > 0 and optim_step % 10 == 0: metric = student_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') - if optim_step % 50 == 0 and optim_step > 0: + if optim_step > 0 and optim_step % 50 == 0: student_model.save(f'gkd-offpolicy-ckpt-{optim_step}') - + optim_step += 1 student_model.save('gkd-offpolicy-final') diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index e73d46b1..7c1a9af4 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -73,7 +73,8 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) - +SYSTEM_PROMPT = ('You are a helpful math assistant. 
Solve the problem step by step and put ' + 'your final answer within #### ') ADAPTER_NAME = 'default' @@ -83,7 +84,7 @@ def create_dataset(): """Prompt-only dataset; student vLLM will generate completions on-policy.""" dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) - dataset.map(GSM8KProcessor()) + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) return dataset @@ -228,17 +229,18 @@ def main(): ) # 5. Student forward + GKD backward - student_model.forward_backward(inputs=input_data, **teacher_output)() + student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() - optim_step += 1 - if optim_step % 10 == 0: + if optim_step > 0 and optim_step % 10 == 0: metric = student_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') - if optim_step % 50 == 0: + if optim_step > 0 and optim_step % 50 == 0: student_model.save(f'gkd-onpolicy-ckpt-{optim_step}') + optim_step += 1 + student_model.save('gkd-onpolicy-final') logger.info('GKD on-policy training completed.') diff --git a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md index 69dfb41f..eed7766f 100644 --- a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md +++ b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md @@ -52,6 +52,7 @@ actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...) for data in dataloader: sampler_output = sampler.sample(data) + ... 
model_output = actor.forward(sampler_output) ``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" index 5532ac89..1db059e5 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" @@ -52,6 +52,7 @@ actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...) for data in dataloader: sampler_output = sampler.sample(data) + ... model_output = actor.forward(sampler_output) ``` diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index 89c2ce0a..882a7b5b 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -96,7 +96,6 @@ def sync_weights(self, merge_and_sync=True): Returns: None """ - start_time = time.time() model_metadata = self.model.prepare_checkpoint_engine([True] + [False] * (self.model.device_mesh.world_size - 1)) self.sampler.prepare_checkpoint_engine(False) diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 3de9179b..399b7d1f 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -205,10 +205,10 @@ def _generalized_jsd_loss( del s_chunk, t_chunk if beta == 0: - # Forward KL: KL(S || T) + # Forward KL: KL(T || S) jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True) elif beta == 1: - # Reverse KL: KL(T || S) + # Reverse KL: KL(S || T) jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True) else: # Generalised JSD: β·KL(T||M) + (1-β)·KL(S||M) diff --git 
a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index bf78cf99..3dd7f8bf 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -364,7 +364,7 @@ def calculate_loss(self, **kwargs): @remote_function() def backward(self, **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') - + @remote_function(collect='first', lazy_collect=False) def get_lr(self): return self.optimizer_group['default']._get_lr() diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 7234a60a..13b52d99 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, - GSM8KFullProcessor, GSM8KProcessor, SelfCognitionProcessor) + GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 1b763a6a..a451e90c 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -122,8 +122,16 @@ class GSM8KProcessor(Preprocessor): Extracts the ground truth number and stores it in user_data for reward. Only includes system + user messages; assistant response is generated on-policy. """ + system_prompt = ('You are a helpful math assistant. Solve the problem step by step. ' + 'Show your reasoning in tags, then give the final ' + 'numerical answer after ####.\n' + 'For example:\n ... reasoning ... \n#### 42') - system_prompt = ('You are a helpful math assistant. 
Solve the problem step by step and put your final answer within #### ') + def __init__(self, system=None, add_assistant=False): + self.system = system + if self.system is None: + self.system = self.system_prompt + self.add_assistant = add_assistant def extract_ground_truth(self, answer_str: str) -> str: """Extract the number after '####' from GSM8K answer.""" @@ -144,32 +152,11 @@ def preprocess(self, row) -> Trajectory: ground_truth = self.extract_ground_truth(answer) messages = [ - Message(role='system', content=self.system_prompt), - Message(role='user', content=question), - ] - return Trajectory( - messages=messages, - user_data=[('ground_truth', ground_truth)], - ) - - -class GSM8KFullProcessor(GSM8KProcessor): - """Preprocessor for GSM8K dataset (full trajectory, for off-policy distillation). - - Includes system + user + assistant messages with the reference answer. - Used when training on existing responses (off-policy) rather than generating new ones. - """ - - def preprocess(self, row) -> Trajectory: - question = row['question'] - answer = row.get('answer', '') - ground_truth = self.extract_ground_truth(answer) - - messages = [ - Message(role='system', content=self.system_prompt), + Message(role='system', content=self.system), Message(role='user', content=question), - Message(role='assistant', content=answer), ] + if self.add_assistant: + messages.append(Message(role='assistant', content=answer)) return Trajectory( messages=messages, user_data=[('ground_truth', ground_truth)], diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index ec991ed1..440160af 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -21,12 +21,11 @@ """ import asyncio import atexit +import numpy as np import os import threading from typing import Any, Dict, List, Optional, Union -import numpy as np - from twinkle import DeviceMesh, get_logger, remote_class, 
remote_function, requires from twinkle.checkpoint_engine import CheckpointEngineMixin from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory @@ -144,7 +143,10 @@ async def _create_engine_async(self, engine_cls, model_id, engine_kwargs): """Create engine in async context to ensure output_handler starts correctly.""" return engine_cls(model_id=model_id, **engine_kwargs) - def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = '', add_generation_prompt=True) -> InputFeature: + def encode_trajectory_for_vllm(self, + trajectory: Trajectory, + adapter_name: str = '', + add_generation_prompt=True) -> InputFeature: """Encode trajectory for vLLM - does not expand image tokens. Args: @@ -256,8 +258,7 @@ async def _sample_single( logprobs=seq.logprobs, decoded=self.template.decode(seq.tokens), new_input_feature=_convert_ndarray_to_list( - self.template.concat_input_feature(feat, seq.tokens) - ), + self.template.concat_input_feature(feat, seq.tokens)), ) for seq in response.sequences ], prompt_logprobs=response.prompt_logprobs, @@ -272,7 +273,7 @@ async def _sample_single( ) for seq in response.sequences ], prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs) + topk_prompt_logprobs=response.topk_prompt_logprobs) @remote_function(dispatch='slice_dp', collect='flatten', lazy_collect=False) def sample( @@ -322,12 +323,14 @@ def sample( if sampling_params.max_tokens == 0: sampling_params.max_tokens = 1 logprobs_only = True - + if is_trajectory: template = self.template assert template is not None, \ 'Use set_template to add a template when trying to input Trajectory' - encoded_inputs = [self.encode_trajectory_for_vllm(traj, adapter_name, not logprobs_only) for traj in inputs_list] + encoded_inputs = [ + self.encode_trajectory_for_vllm(traj, adapter_name, not logprobs_only) for traj in inputs_list + ] else: encoded_inputs = inputs_list From 
5c36715ef51d43c3d670214979aa77533fe5e20e Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 19:44:41 +0800 Subject: [PATCH 45/56] fix --- README.md | 3 +- README_ZH.md | 1 + cookbook/client/twinkle/self_host/grpo.py | 12 ++--- cookbook/client/twinkle/self_host/sample.py | 13 ++--- cookbook/rl/grpo.py | 11 ++-- .../Components/Advantage/GRPOAdvantage.md | 17 +++---- .../DeviceMesh-and-DeviceGroup.md | 3 +- .../GRPOAdvantage.md" | 23 ++++----- .../DeviceMesh\345\222\214DeviceGroup.md" | 3 +- src/twinkle/sampler/base.py | 2 +- .../sampler/torch_sampler/torch_sampler.py | 4 +- .../sampler/vllm_sampler/vllm_sampler.py | 4 +- src/twinkle/server/sampler/tinker_handlers.py | 35 ++++++------- .../server/sampler/twinkle_handlers.py | 51 ++++++++++--------- src/twinkle_client/sampler/vllm_sampler.py | 6 +-- tests/sampler/align_swift.py | 8 +-- tests/sampler/test_30b_weight_sync.py | 23 +++++---- tests/sampler/test_megatron_weight_sync.py | 13 ++--- tests/sampler/test_weight_sync.py | 13 ++--- 19 files changed, 128 insertions(+), 117 deletions(-) diff --git a/README.md b/README.md index b1fa32ad..829f1134 100644 --- a/README.md +++ b/README.md @@ -101,9 +101,8 @@ Or use ModelScope's [official image](https://www.modelscope.cn/docs/intro/enviro ## Changelog +- 🎉2026-03-19 Support GKD training ,please refer to this [cookbook](cookbook/rl/gkd_on_policy.py). - 🎉2026-02-13 Initial version of Twinkle✨ released, including SFT/PT/RL support for text models. -We also made available serverless training capabilities on [ModelScope](https://modelscope.cn) via -Tinker-compatible APIs. 
## Training as a Service on ModelScope diff --git a/README_ZH.md b/README_ZH.md index 75c69624..1fc8267a 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -91,6 +91,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl ## 更新日志 +🎉2026-03-19 支持GKD蒸馏能力,参考[cookbook](cookbook/rl/gkd_on_policy.py)。 🎉2026-02-13 Twinkle✨ 初始版本发布,支持文本模型的SFT/PT/RL训练。我们还通过兼容Tinker的API,在魔搭社区上提供了无服务器训练功能。 ## ModelScope 的训练服务 diff --git a/cookbook/client/twinkle/self_host/grpo.py b/cookbook/client/twinkle/self_host/grpo.py index 8291fb91..f2a24e83 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/grpo.py @@ -38,7 +38,6 @@ from twinkle_client.dataset import Dataset from twinkle_client.model import MultiLoraTransformersModel from twinkle_client.sampler import vLLMSampler -from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() @@ -153,7 +152,7 @@ def train(): logger.info(f'Step {step}: Saved weights to {current_adapter_uri}') # ========== 2. Sample completions ========== - sample_response = sampler.sample( + sample_responses = sampler.sample( inputs=prompts, sampling_params=sampling_params, adapter_uri=current_adapter_uri, @@ -164,10 +163,11 @@ def train(): all_old_logps: List[List[float]] = [] all_completion_lengths: List[int] = [] - for sequence in sample_response.sequences: - all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) - all_completion_lengths.append(len(sequence.tokens)) + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append(sequence.logprobs) + all_completion_lengths.append(len(sequence.tokens)) # ========== 3. 
Compute rewards ========== diff --git a/cookbook/client/twinkle/self_host/sample.py b/cookbook/client/twinkle/self_host/sample.py index d800b635..fdb0e17e 100644 --- a/cookbook/client/twinkle/self_host/sample.py +++ b/cookbook/client/twinkle/self_host/sample.py @@ -73,7 +73,7 @@ def sample(): # - sampling_params: controls generation behavior # - adapter_uri: optional LoRA adapter path for fine-tuned inference # - num_samples: number of completions per prompt - response = sampler.sample( + responses = sampler.sample( inputs=[trajectory] * num_prompts, sampling_params=sampling_params, adapter_uri=ADAPTER_URI, @@ -83,12 +83,13 @@ def sample(): # Step 8: Decode and print the results tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - logger.info(f'Generated {len(response.sequences)} sequences ' - f'({num_prompts} prompts x {num_samples} samples)') + for response in responses: + logger.info(f'Generated {len(response.sequences)} sequences ' + f'({num_prompts} prompts x {num_samples} samples)') - for i, seq in enumerate(response.sequences): - text = tokenizer.decode(seq.tokens, skip_special_tokens=True) - logger.info(f'Sequence {i}:\n {text}\n') + for i, seq in enumerate(response.sequences): + text = tokenizer.decode(seq.tokens, skip_special_tokens=True) + logger.info(f'Sequence {i}:\n {text}\n') if __name__ == '__main__': diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index 590ca719..67434348 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -125,7 +125,7 @@ def main(): global_prompts = batch if isinstance(batch, list) else [batch] ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() - sample_response = sampler.sample( + sample_responses = sampler.sample( global_prompts*NUM_GENERATIONS, sampling_params, num_samples=1, @@ -135,10 +135,11 @@ def main(): all_old_logps: List[List[float]] = [] all_completion_lengths: List[int] = [] - for sequence in sample_response.sequences: - 
all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) - all_completion_lengths.append(len(sequence.tokens)) + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append(sequence.logprobs) + all_completion_lengths.append(len(sequence.tokens)) total_rewards, format_rewards, accuracy_rewards = compute_rewards( all_input_data ) diff --git a/docs/source_en/Components/Advantage/GRPOAdvantage.md b/docs/source_en/Components/Advantage/GRPOAdvantage.md index fb92e7b4..accb24bb 100644 --- a/docs/source_en/Components/Advantage/GRPOAdvantage.md +++ b/docs/source_en/Components/Advantage/GRPOAdvantage.md @@ -38,28 +38,27 @@ Using the advantage function in GRPO training: from twinkle.advantage import GRPOAdvantage from twinkle.model import TransformersModel from twinkle.sampler import vLLMSampler -from twinkle.reward import MathReward # Create components actor = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') sampler = vLLMSampler(model_id='ms://Qwen/Qwen3.5-4B') -reward_fn = MathReward() +reward_fn = ... advantage_fn = GRPOAdvantage() # Training loop for batch in dataloader: - # 1. Sample generation - response = sampler.sample(batch, num_samples=4) + # Sample generation + sample_response = sampler.sample(batch, num_samples=4) + input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] + ... + rewards = reward_fn(...) - # 2. Calculate rewards - rewards = reward_fn(response.trajectories, batch.ground_truths) - - # 3. Calculate advantages + # Calculate advantages advantages = advantage_fn(rewards, num_generations=4) # 4. 
Policy optimization loss = actor.forward_backward( - inputs=response.inputs, + inputs=input_data, advantages=advantages ) actor.clip_grad_and_step() diff --git a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md index eed7766f..169adb86 100644 --- a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md +++ b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md @@ -52,8 +52,9 @@ actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...) for data in dataloader: sampler_output = sampler.sample(data) + input_data = [seq.new_input_feature for response in sampler_output for seq in response.sequences] ... - model_output = actor.forward(sampler_output) + model_output = actor.forward(input_data) ``` We analyze the data transfer situation using the pseudo-code above. diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md" index 9dc0635d..d167b353 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md" @@ -38,28 +38,27 @@ GRPO 将样本分组(每组对应一个 prompt 的多个生成),然后在组内: from twinkle.advantage import GRPOAdvantage from twinkle.model import TransformersModel from twinkle.sampler import vLLMSampler -from twinkle.reward import MathReward -# 创建组件 +# Create components actor = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') sampler = vLLMSampler(model_id='ms://Qwen/Qwen3.5-4B') -reward_fn = MathReward() +reward_fn = ... advantage_fn = GRPOAdvantage() -# 训练循环 +# Training loop for batch in dataloader: - # 1. 
采样生成 - response = sampler.sample(batch, num_samples=4) + # Sample generation + sample_response = sampler.sample(batch, num_samples=4) + input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] + ... + rewards = reward_fn(...) - # 2. 计算奖励 - rewards = reward_fn(response.trajectories, batch.ground_truths) - - # 3. 计算优势 + # Calculate advantages advantages = advantage_fn(rewards, num_generations=4) - # 4. 策略优化 + # 4. Policy optimization loss = actor.forward_backward( - inputs=response.inputs, + inputs=input_data, advantages=advantages ) actor.clip_grad_and_step() diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" index 1db059e5..00ec1f30 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" @@ -52,8 +52,9 @@ actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...) for data in dataloader: sampler_output = sampler.sample(data) + input_data = [seq.new_input_feature for response in sampler_output for seq in response.sequences] ... - model_output = actor.forward(sampler_output) + model_output = actor.forward(input_data) ``` 我们以上面的伪代码来分析数据传递情况。 diff --git a/src/twinkle/sampler/base.py b/src/twinkle/sampler/base.py index 1ceb1935..4aaeb3e0 100644 --- a/src/twinkle/sampler/base.py +++ b/src/twinkle/sampler/base.py @@ -24,7 +24,7 @@ def sample( adapter_name: str = '', *, num_samples: int = 1, - ) -> SampleResponse: + ) -> List[SampleResponse]: """Sample responses for given inputs. 
Args: diff --git a/src/twinkle/sampler/torch_sampler/torch_sampler.py b/src/twinkle/sampler/torch_sampler/torch_sampler.py index 695033e0..8c7643ac 100644 --- a/src/twinkle/sampler/torch_sampler/torch_sampler.py +++ b/src/twinkle/sampler/torch_sampler/torch_sampler.py @@ -55,7 +55,7 @@ def sample( inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None, adapter_name: str = '', - ) -> SampleResponse: + ) -> List[SampleResponse]: """Sample responses for given inputs. Args: @@ -154,4 +154,4 @@ def sample( logprobs=seq_logprobs, )) - return SampleResponse(sequences=all_sequences) + return [SampleResponse(sequences=all_sequences)] diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 440160af..e4705017 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -219,7 +219,7 @@ async def _sample_single( lora_request: Optional[Any] = None, *, logprobs_only: bool = False, - ) -> List[SampledSequence]: + ) -> SampleResponse: """Sample a single input asynchronously. Args: @@ -231,7 +231,7 @@ async def _sample_single( logprobs_only: Only return logprobs (no generated tokens). Returns: - List of num_samples SampledSequence objects. 
+ A SampleResponse object """ input_ids = feat['input_ids'] if hasattr(input_ids, 'tolist'): diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py index 4cd574be..16b75040 100644 --- a/src/twinkle/server/sampler/tinker_handlers.py +++ b/src/twinkle/server/sampler/tinker_handlers.py @@ -79,31 +79,32 @@ async def _do_sample(): stop=body.sampling_params.stop, ) - response = self.sampler.sample( + responses = self.sampler.sample( inputs=[prompt_inputs] * body.num_samples, sampling_params=sampling_params, adapter_path=adapter_uri, ) - # Convert twinkle SampleResponse to tinker types tinker_sequences = [] - for seq in response.sequences: - logprobs = None - if seq.logprobs is not None: - if any(lp is None for lp in seq.logprobs): - logprobs = None - else: - logprobs = list(seq.logprobs) - tinker_sequences.append( - types.SampledSequence( - stop_reason=seq.stop_reason, - tokens=list(seq.tokens), - logprobs=logprobs, - )) + for response in responses: + # Convert twinkle SampleResponse to tinker types + for seq in response.sequences: + logprobs = None + if seq.logprobs is not None: + if any(lp is None for lp in seq.logprobs): + logprobs = None + else: + logprobs = list(seq.logprobs) + tinker_sequences.append( + types.SampledSequence( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=logprobs, + )) return types.SampleResponse( sequences=tinker_sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, + prompt_logprobs=responses[0].prompt_logprobs, + topk_prompt_logprobs=responses[0].topk_prompt_logprobs, ) except Exception: logger.error(traceback.format_exc()) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index a31f4046..6ce03eae 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -8,7 +8,7 @@ import traceback from fastapi import 
Depends, FastAPI, HTTPException, Request -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, List, Optional if TYPE_CHECKING: from .app import SamplerManagement @@ -60,8 +60,9 @@ def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> type return types.CreateResponse() @app.post('/twinkle/sample', response_model=types.SampleResponseModel) - def sample(request: Request, body: types.SampleRequest, - self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModel: + def sample( + request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> list[types.SampleResponseModel]: """Sample completions from the model. Supports Trajectory or InputFeature inputs, with optional LoRA adapter. @@ -99,32 +100,36 @@ def sample(request: Request, body: types.SampleRequest, params = SamplingParams.from_dict(body.sampling_params) # Call sampler - response = self.sampler.sample( + responses = self.sampler.sample( inputs, params, adapter_name=full_adapter_name, adapter_path=adapter_path, num_samples=body.num_samples, ) - if callable(response): - response = response() - - sequences = [ - types.SampledSequenceModel( - stop_reason=seq.stop_reason, - tokens=list(seq.tokens), - logprobs=list(seq.logprobs) if seq.logprobs is not None else None, - decoded=seq.decoded, - new_input_feature=_serialize_input_feature(seq.new_input_feature) - if seq.new_input_feature is not None else None, - ) for seq in response.sequences - ] - - return types.SampleResponseModel( - sequences=sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) + if callable(responses): + responses = responses() + + sample_models = [] + for response in responses: + sequences = [ + types.SampledSequenceModel( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=list(seq.logprobs) if seq.logprobs is not None else None, + decoded=seq.decoded, + 
new_input_feature=_serialize_input_feature(seq.new_input_feature) + if seq.new_input_feature is not None else None, + ) for seq in response.sequences + ] + + sample_models.append( + types.SampleResponseModel( + sequences=sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + )) + return sample_models except Exception: logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=traceback.format_exc()) diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py index a19984c3..ea5e8767 100644 --- a/src/twinkle_client/sampler/vllm_sampler.py +++ b/src/twinkle_client/sampler/vllm_sampler.py @@ -57,7 +57,7 @@ def sample( adapter_name: str = '', adapter_uri: Optional[str] = None, num_samples: int = 1, - ) -> SampleResponseModel: + ) -> List[SampleResponseModel]: """Sample from the model. Args: @@ -68,7 +68,7 @@ def sample( num_samples: Number of completions to generate per prompt. Returns: - SampleResponseModel with 'sequences' list, each containing tokens, logprobs, stop_reason. + A list of sampleResponseModel with 'sequences' list, each containing tokens, logprobs, stop_reason. 
""" json_data = { 'inputs': inputs, @@ -84,7 +84,7 @@ def sample( json_data=json_data ) response.raise_for_status() - return SampleResponseModel(**response.json()) + return [SampleResponseModel(**r) for r in response.json()] def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse: """Set the template for encoding trajectories.""" diff --git a/tests/sampler/align_swift.py b/tests/sampler/align_swift.py index dc33ff36..a0b26882 100644 --- a/tests/sampler/align_swift.py +++ b/tests/sampler/align_swift.py @@ -121,7 +121,7 @@ def test_llm_torch_sampler(): trajectory = Trajectory(messages=LLM_MESSAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0) resp = sampler.sample([trajectory], sampling_params=sampling_params) - tokens = resp.sequences[0].tokens + tokens = resp[0].sequences[0].tokens twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True) del sampler clean_cache() @@ -157,7 +157,7 @@ def test_llm_vllm_sampler(): resp = sampler.sample([trajectory] * 16, sampling_params=sampling_params) end_time = time.time() print(f'Twinkle inference time: {end_time - st_time} seconds') - tokens = resp.sequences[0].tokens + tokens = resp[0].sequences[0].tokens twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True) del sampler clean_cache() @@ -271,7 +271,7 @@ def test_mllm_torch_sampler(): trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0) resp = sampler.sample([trajectory], sampling_params=sampling_params) - tokens = resp.sequences[0].tokens + tokens = resp[0].sequences[0].tokens twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True) del sampler clean_cache() @@ -299,7 +299,7 @@ def test_mllm_vllm_sampler(): trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0) resp = sampler.sample([trajectory], 
sampling_params=sampling_params) - tokens = resp.sequences[0].tokens + tokens = resp[0].sequences[0].tokens twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True) del sampler clean_cache() diff --git a/tests/sampler/test_30b_weight_sync.py b/tests/sampler/test_30b_weight_sync.py index 9eb6b976..6e85462b 100644 --- a/tests/sampler/test_30b_weight_sync.py +++ b/tests/sampler/test_30b_weight_sync.py @@ -161,17 +161,18 @@ def test_weight_sync(model_gpus: int = 2, sampler_gpus: int = 1, vllm_tp: int = # Quick sample to verify model works log('\n--- Sampling after sync ---') traj = Trajectory(messages=[{'role': 'user', 'content': 'What is 2+2?'}]) - response = sampler.sample(traj, SamplingParams(max_tokens=32, temperature=0.0)) - if callable(response): - response = response() - if response and response.sequences: - tokens = response.sequences[0].tokens - if hasattr(tokens, 'tolist'): - tokens = tokens.tolist() - from modelscope import AutoTokenizer - tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - text = tok.decode(tokens, skip_special_tokens=True) - log(f" Output: '{text[:200]}'") + responses = sampler.sample(traj, SamplingParams(max_tokens=32, temperature=0.0)) + if callable(responses): + responses = responses() + for response in responses: + if response and response.sequences: + tokens = response.sequences[0].tokens + if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + from modelscope import AutoTokenizer + tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + text = tok.decode(tokens, skip_special_tokens=True) + log(f" Output: '{text[:200]}'") log('\n--- PASS: Weight sync completed without OOM or hang ---') log(f' Base sync: {base_time:.2f}s, LoRA sync: {lora_time:.2f}s') diff --git a/tests/sampler/test_megatron_weight_sync.py b/tests/sampler/test_megatron_weight_sync.py index a8021f7a..2d32b5f3 100644 --- a/tests/sampler/test_megatron_weight_sync.py +++ 
b/tests/sampler/test_megatron_weight_sync.py @@ -199,12 +199,13 @@ def test_megatron_weight_sync( # ── Helper: sample one prompt ───────────────────────────────────── def do_sample(prompt: str, max_tokens: int = 32) -> str: traj = Trajectory(messages=[{'role': 'user', 'content': prompt}]) - response = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0))) - if response and response.sequences: - tokens = response.sequences[0].tokens - if hasattr(tokens, 'tolist'): - tokens = tokens.tolist() - return tokenizer.decode(tokens, skip_special_tokens=True) + responses = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0))) + for response in responses: + if response and response.sequences: + tokens = response.sequences[0].tokens + if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + return tokenizer.decode(tokens, skip_special_tokens=True) return '' # ── Sample BEFORE sync (dummy weights → garbage) ────────────────── diff --git a/tests/sampler/test_weight_sync.py b/tests/sampler/test_weight_sync.py index d22662af..af63cdb4 100644 --- a/tests/sampler/test_weight_sync.py +++ b/tests/sampler/test_weight_sync.py @@ -178,12 +178,13 @@ def test_standalone_weight_sync(model_gpus: int = 1, sampler_gpus: int = 1): # ── Helper: sample one prompt ───────────────────────────────────── def do_sample(prompt: str, max_tokens: int = 32) -> str: traj = Trajectory(messages=[{'role': 'user', 'content': prompt}]) - response = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0))) - if response and response.sequences: - tokens = response.sequences[0].tokens - if hasattr(tokens, 'tolist'): - tokens = tokens.tolist() - return tokenizer.decode(tokens, skip_special_tokens=True) + responses = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0))) + for response in responses: + if response and response.sequences: + tokens = response.sequences[0].tokens + 
if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + return tokenizer.decode(tokens, skip_special_tokens=True) return '' # ── Sample BEFORE sync (dummy weights → garbage) ────────────────── From 5dc401dab45c4b8c09a9232d9259b563cae471b5 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 19:46:45 +0800 Subject: [PATCH 46/56] fix docs --- docs/source_en/Components/Sampler/Sampler.md | 2 +- docs/source_en/Components/Sampler/TorchSampler.md | 2 +- docs/source_en/Components/Sampler/vLLMSampler.md | 4 ++-- .../\351\207\207\346\240\267\345\231\250/Sampler.md" | 2 +- .../\351\207\207\346\240\267\345\231\250/TorchSampler.md" | 2 +- .../\351\207\207\346\240\267\345\231\250/vLLMSampler.md" | 4 ++-- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source_en/Components/Sampler/Sampler.md b/docs/source_en/Components/Sampler/Sampler.md index 23f46dae..d63be7de 100644 --- a/docs/source_en/Components/Sampler/Sampler.md +++ b/docs/source_en/Components/Sampler/Sampler.md @@ -15,7 +15,7 @@ class Sampler(ABC): adapter_name: str = '', *, num_samples: int = 1, - ) -> SampleResponse: + ) -> List[SampleResponse]: """Sample from given inputs""" ... 
diff --git a/docs/source_en/Components/Sampler/TorchSampler.md b/docs/source_en/Components/Sampler/TorchSampler.md index 6076d801..a39c5571 100644 --- a/docs/source_en/Components/Sampler/TorchSampler.md +++ b/docs/source_en/Components/Sampler/TorchSampler.md @@ -13,7 +13,7 @@ sampler = TorchSampler( device_mesh=DeviceMesh.from_sizes(dp_size=1), ) -response = sampler.sample(trajectories, sampling_params=params) +responses = sampler.sample(trajectories, sampling_params=params) ``` ## Features diff --git a/docs/source_en/Components/Sampler/vLLMSampler.md b/docs/source_en/Components/Sampler/vLLMSampler.md index 95e7d283..d015d566 100644 --- a/docs/source_en/Components/Sampler/vLLMSampler.md +++ b/docs/source_en/Components/Sampler/vLLMSampler.md @@ -28,7 +28,7 @@ params = SamplingParams( ) # Perform sampling -response = sampler.sample( +responses = sampler.sample( trajectories, sampling_params=params, adapter_name='my_lora', @@ -66,7 +66,7 @@ sampler = vLLMSampler( ) # sample method executes in remote worker -response = sampler.sample(trajectories, sampling_params=params) +responses = sampler.sample(trajectories, sampling_params=params) ``` ## Environment Variables diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md" index 3fb91acc..1cb12eaf 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md" @@ -15,7 +15,7 @@ class Sampler(ABC): adapter_name: str = '', *, num_samples: int = 1, - ) -> SampleResponse: + ) -> List[SampleResponse]: """对给定输入进行采样""" ... 
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md" index bcdce5b9..e6a96902 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md" @@ -13,7 +13,7 @@ sampler = TorchSampler( device_mesh=DeviceMesh.from_sizes(dp_size=1), ) -response = sampler.sample(trajectories, sampling_params=params) +responses = sampler.sample(trajectories, sampling_params=params) ``` ## 特性 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" index 38b4e5be..eeed825f 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" @@ -28,7 +28,7 @@ params = SamplingParams( ) # 进行采样 -response = sampler.sample( +responses = sampler.sample( trajectories, sampling_params=params, adapter_name='my_lora', @@ -66,7 +66,7 @@ sampler = vLLMSampler( ) # sample 方法会在 remote worker 中执行 -response = sampler.sample(trajectories, sampling_params=params) +responses = sampler.sample(trajectories, sampling_params=params) ``` ## 环境变量 From 2a0bfe0f2239cf8af83a367885df5343165c515f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 20:10:27 +0800 Subject: [PATCH 47/56] fix --- cookbook/client/twinkle/self_host/grpo.py | 2 +- cookbook/client/twinkle/self_host/sample.py | 2 +- cookbook/rl/grpo.py | 3 +-- src/twinkle/server/sampler/twinkle_handlers.py | 2 +- src/twinkle_client/types/sampler.py | 3 +-- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cookbook/client/twinkle/self_host/grpo.py 
b/cookbook/client/twinkle/self_host/grpo.py index f2a24e83..6d8185ba 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/grpo.py @@ -126,6 +126,7 @@ def train(): 'max_tokens': MAX_NEW_TOKENS, 'temperature': TEMPERATURE, 'top_p': 0.95, + 'num_samples': NUM_GENERATIONS, } # Track the current adapter path for sampling @@ -156,7 +157,6 @@ def train(): inputs=prompts, sampling_params=sampling_params, adapter_uri=current_adapter_uri, - num_samples=NUM_GENERATIONS, ) all_input_data: List[Dict[str, Any]] = [] diff --git a/cookbook/client/twinkle/self_host/sample.py b/cookbook/client/twinkle/self_host/sample.py index fdb0e17e..3b02c4ec 100644 --- a/cookbook/client/twinkle/self_host/sample.py +++ b/cookbook/client/twinkle/self_host/sample.py @@ -66,6 +66,7 @@ def sample(): sampling_params = { 'max_tokens': 128, 'temperature': 1.0, + 'num_samples': num_samples, } # Step 7: Call the sampler @@ -77,7 +78,6 @@ def sample(): inputs=[trajectory] * num_prompts, sampling_params=sampling_params, adapter_uri=ADAPTER_URI, - num_samples=num_samples, ) # Step 8: Decode and print the results diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index 67434348..f6082668 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -113,7 +113,7 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) optim_step = 0 logger.info(get_device_placement()) @@ -128,7 +128,6 @@ def main(): sample_responses = sampler.sample( global_prompts*NUM_GENERATIONS, sampling_params, - num_samples=1, ) all_input_data: List[Dict[str, Any]] = [] diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 6ce03eae..83136f57 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -98,6 +98,7 @@ 
def sample( params = None if body.sampling_params: params = SamplingParams.from_dict(body.sampling_params) + params.num_samples = body.num_samples # Call sampler responses = self.sampler.sample( @@ -105,7 +106,6 @@ def sample( params, adapter_name=full_adapter_name, adapter_path=adapter_path, - num_samples=body.num_samples, ) if callable(responses): responses = responses() diff --git a/src/twinkle_client/types/sampler.py b/src/twinkle_client/types/sampler.py index cf370330..a1b579b7 100644 --- a/src/twinkle_client/types/sampler.py +++ b/src/twinkle_client/types/sampler.py @@ -14,11 +14,10 @@ class SampleRequest(BaseModel): """Request body for the /sample endpoint.""" inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts') sampling_params: Optional[Dict[str, Any]] = Field( - None, description='Sampling parameters (max_tokens, temperature, etc.)') + None, description='Sampling parameters (max_tokens, temperature, num_samples, etc.)') adapter_name: str = Field('', description='Adapter name for LoRA inference') adapter_uri: Optional[str] = Field( None, description='Adapter URI (twinkle:// path or local path) for LoRA inference') - num_samples: int = Field(1, description='Number of completions to generate per prompt') class SampledSequenceModel(BaseModel): From a46bd59f07973a78e32e6405193df7b83709d486 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 20:55:37 +0800 Subject: [PATCH 48/56] fix --- cookbook/client/twinkle/self_host/grpo.py | 2 +- cookbook/rl/grpo.py | 6 +++--- docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md | 2 +- docs/source_en/Usage Guide/Quick-Start.md | 2 +- ...5\346\234\200\344\275\263\345\256\236\350\267\265.md" | 2 +- .../\345\277\253\351\200\237\345\274\200\345\247\213.md" | 2 +- src/twinkle/model/megatron/megatron.py | 9 +++++++-- src/twinkle/model/transformers/transformers.py | 9 +++++++++ 8 files changed, 24 insertions(+), 10 deletions(-) diff --git a/cookbook/client/twinkle/self_host/grpo.py 
b/cookbook/client/twinkle/self_host/grpo.py index 6d8185ba..bfc0f643 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/grpo.py @@ -166,7 +166,7 @@ def train(): for sample_response in sample_responses: for sequence in sample_response.sequences: all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) all_completion_lengths.append(len(sequence.tokens)) # ========== 3. Compute rewards ========== diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index f6082668..e35dc648 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -20,7 +20,7 @@ logger = get_logger() -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-0.6B') USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) @@ -113,7 +113,7 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1) optim_step = 0 logger.info(get_device_placement()) @@ -137,7 +137,7 @@ def main(): for sample_response in sample_responses: for sequence in sample_response.sequences: all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) all_completion_lengths.append(len(sequence.tokens)) total_rewards, format_rewards, accuracy_rewards = compute_rewards( all_input_data diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md index 660179e0..4e53972d 100644 --- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md +++ b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md @@ -278,7 +278,7 @@ 
def main(): for sequence in sample_response.sequences: all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) all_completion_lengths.append(len(sequence.tokens)) # Compute rewards diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index d3629a19..34a45342 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -349,7 +349,7 @@ def main(): for sequence in sample_response.sequences: all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) all_completion_lengths.append(len(sequence.tokens)) total_rewards, format_rewards, accuracy_rewards = compute_rewards( all_input_data diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index 0987e6b7..6f731118 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -278,7 +278,7 @@ def main(): for sequence in sample_response.sequences: all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) all_completion_lengths.append(len(sequence.tokens)) # 计算奖励 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" 
index 230d5a51..253f620b 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -351,7 +351,7 @@ def main(): for sequence in sample_response.sequences: all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) all_completion_lengths.append(len(sequence.tokens)) total_rewards, format_rewards, accuracy_rewards = compute_rewards( all_input_data diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 3dd7f8bf..eb5bed2a 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -405,6 +405,7 @@ def forward_backward(self, adapter_name = kwargs.pop('adapter_name', self._get_default_group()) temperature = float(kwargs.pop('temperature', 1.0)) forward_only = kwargs.pop('forward_only', False) + return_logits = kwargs.pop('return_logits', False) optimizer_config = self.optimizer_group[adapter_name] loss_instance = self.optimizer_group[adapter_name].loss_instance if not inputs: @@ -571,11 +572,15 @@ def forward_step_func(data_iterator, model): optimizer_config.inputs = inputs if logps and len({_logps.shape[1] for _logps in logps}) == 1: logps = torch.cat(logps, dim=0) + if logits and len({_logits.shape[1] for _logits in logits}) == 1: + logits = torch.cat(logits, dim=0) if isinstance(loss, torch.Tensor): loss = loss.detach().cpu().float().numpy() + if not return_logits: + logits = None if not forward_only: - optimizer_config.outputs = ModelOutput(logits=None, loss=loss, logps=logps) - return ModelOutput(logits=None, loss=loss, logps=logps) + optimizer_config.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + return ModelOutput(logits=logits, loss=loss, logps=logps) 
@remote_function(dispatch='all') def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 93811dfc..d00e80ed 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist import transformers +from copy import copy from dataclasses import dataclass, field from peft import PeftConfig, PeftModel, get_peft_model from peft.utils import load_peft_weights, set_peft_model_state_dict @@ -367,6 +368,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec """ adapter_name = kwargs.pop('adapter_name', self._get_default_group()) temperature = float(kwargs.pop('temperature', 1.0)) + return_logits = kwargs.pop('return_logits', False) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -401,7 +403,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec logits = outputs['logits'] logits.div_(temperature) outputs['logps'] = selective_log_softmax(logits, masked_labels) + outputs = copy(outputs) outputs['past_key_values'] = None + if not return_logits: + outputs['logits'] = None return outputs @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) @@ -417,6 +422,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T """ adapter_name = kwargs.pop('adapter_name', self._get_default_group()) temperature = float(kwargs.pop('temperature', 1.0)) + return_logits = kwargs.pop('return_logits', False) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -452,7 +458,10 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T logits = outputs['logits'] logits.div_(temperature) outputs['logps'] = selective_log_softmax(logits, 
masked_labels) + outputs = copy(outputs) outputs['past_key_values'] = None + if not return_logits: + outputs['logits'] = None return outputs @remote_function(collect='mean') From 758ab1bb01c2fe4470855ee9cb7724e93af91a74 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 21:11:53 +0800 Subject: [PATCH 49/56] fix docs --- .../Usage Guide/Introduction-with-Qwen3.5.md | 14 +++++++------- docs/source_en/Usage Guide/Quick-Start.md | 14 +++++++------- ...234\200\344\275\263\345\256\236\350\267\265.md" | 14 +++++++------- ...277\253\351\200\237\345\274\200\345\247\213.md" | 14 +++++++------- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md index 4e53972d..105b2320 100644 --- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md +++ b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md @@ -250,7 +250,7 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) optim_step = 0 logger.info(get_device_placement()) @@ -266,20 +266,20 @@ def main(): sampler.reset_prefix_cache() # Group sampling: sample NUM_GENERATIONS completions per prompt - sample_response = sampler.sample( + sample_responses = sampler.sample( global_prompts * NUM_GENERATIONS, sampling_params, - num_samples=1, ) all_input_data = [] all_old_logps = [] all_completion_lengths = [] - for sequence in sample_response.sequences: - all_input_data.append(sequence.new_input_feature) - all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) - all_completion_lengths.append(len(sequence.tokens)) + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in 
sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) # Compute rewards total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data) diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 34a45342..509f166d 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -327,7 +327,7 @@ def main(): ) advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) optim_step = 0 print(get_device_placement()) @@ -338,19 +338,19 @@ def main(): global_prompts = batch if isinstance(batch, list) else [batch] ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() - sample_response = sampler.sample( + sample_responses = sampler.sample( global_prompts*NUM_GENERATIONS, sampling_params, - num_samples=1, ) all_input_data: List[Dict[str, Any]] = [] all_old_logps: List[List[float]] = [] all_completion_lengths: List[int] = [] - for sequence in sample_response.sequences: - all_input_data.append(sequence.new_input_feature) - all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) - all_completion_lengths.append(len(sequence.tokens)) + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) total_rewards, format_rewards, accuracy_rewards = compute_rewards( all_input_data ) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index 6f731118..56998d73 100644 --- 
"a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -250,7 +250,7 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) optim_step = 0 logger.info(get_device_placement()) @@ -266,20 +266,20 @@ def main(): sampler.reset_prefix_cache() # 组采样:每个 prompt 采样 NUM_GENERATIONS 个结果 - sample_response = sampler.sample( + sample_responses = sampler.sample( global_prompts * NUM_GENERATIONS, sampling_params, - num_samples=1, ) all_input_data = [] all_old_logps = [] all_completion_lengths = [] - for sequence in sample_response.sequences: - all_input_data.append(sequence.new_input_feature) - all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) - all_completion_lengths.append(len(sequence.tokens)) + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) # 计算奖励 total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 253f620b..d59afc67 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -329,7 +329,7 @@ def 
main(): ) advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) optim_step = 0 print(get_device_placement()) @@ -340,19 +340,19 @@ def main(): global_prompts = batch if isinstance(batch, list) else [batch] ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() - sample_response = sampler.sample( + sample_responses = sampler.sample( global_prompts*NUM_GENERATIONS, sampling_params, - num_samples=1, ) all_input_data: List[Dict[str, Any]] = [] all_old_logps: List[List[float]] = [] all_completion_lengths: List[int] = [] - for sequence in sample_response.sequences: - all_input_data.append(sequence.new_input_feature) - all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) - all_completion_lengths.append(len(sequence.tokens)) + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) total_rewards, format_rewards, accuracy_rewards = compute_rewards( all_input_data ) From aeb0ec127177f3c42ec162174131b88a3347f3e2 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 21:18:57 +0800 Subject: [PATCH 50/56] fix --- docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md | 2 +- docs/source_en/Usage Guide/Quick-Start.md | 2 +- .../Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" | 2 +- .../\345\277\253\351\200\237\345\274\200\345\247\213.md" | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md index 105b2320..978b5af1 100644 --- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md +++ b/docs/source_en/Usage 
Guide/Introduction-with-Qwen3.5.md @@ -250,7 +250,7 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1) optim_step = 0 logger.info(get_device_placement()) diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 509f166d..6a05a53f 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -327,7 +327,7 @@ def main(): ) advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1) optim_step = 0 print(get_device_placement()) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index 56998d73..c8e92c3b 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -250,7 +250,7 @@ def main(): advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1) optim_step = 0 logger.info(get_device_placement()) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" 
"b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index d59afc67..0b8e386a 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -329,7 +329,7 @@ def main(): ) advantage_fn = GRPOAdvantage() metrics = CompletionRewardMetric() - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1) + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1) optim_step = 0 print(get_device_placement()) From c26a70890221e914fac08465a598bffe22ae5c67 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 21:35:18 +0800 Subject: [PATCH 51/56] fix --- cookbook/rl/gkd_off_policy.py | 17 +++++++++-------- cookbook/rl/gkd_on_policy.py | 25 +++++++++++++------------ cookbook/rl/grpo.py | 2 +- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 92f82784..3315c962 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -6,10 +6,11 @@ Pipeline: 1. DataLoader supplies full-text batches (prompt + reference answer). 2. Teacher vLLM sampler computes top-k prompt logprobs on the sequences. - 3. Student TransformersModel runs forward_backward() with GKDLoss. + 3. Student MegatronModel runs forward_backward() with GKDLoss. Key difference from on-policy: - No student sampler needed (responses already in the dataset). + - No weight sync needed (student doesn't sample). - Faster per-step (no generation latency), but less exploration. Architecture (Ray): @@ -20,20 +21,20 @@ │ student_model.forward_backward(teacher_output=...) 
──► GKD │ └─────────────────────────────────────────────────────────────────┘ │ - vLLMSampler + TransformersModel + vLLMSampler + MegatronModel (teacher) (student) Environment variables (all optional): STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B) TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B) - MODEL_GPUS – GPUs for student model (default: 4) - SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 2) - BATCH_SIZE – global batch size (default: 8) - MAX_STEPS – total optimisation steps (default: 200) - LR – learning rate (default: 1e-4) + MODEL_GPUS – GPUs for student model (default: 8) + SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 8) + BATCH_SIZE – global batch size (default: 16) + MAX_STEPS – total optimisation steps (default: 1000) + LR – learning rate (default: 5e-5) GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) GKD_TEMPERATURE – distillation temperature (default: 1.0) - GKD_TOPK – top-k vocab for teacher logprobs (default: 20) + GKD_TOPK – top-k vocab for teacher logprobs (default: 64) """ import os diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 7c1a9af4..f134f0de 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -5,7 +5,7 @@ to match the teacher's token distribution. Pipeline: - 1. DataLoader supplies prompt-only batches. + 1. Sync student model weights to student vLLM sampler. 2. Student vLLM sampler generates completions on-the-fly. 3. Teacher vLLM sampler computes top-k prompt logprobs on generated sequences. 4. Student TransformersModel runs forward_backward() with GKDLoss. 
@@ -13,27 +13,28 @@ Architecture (Ray): ┌─────────────────────────────────────────────────────────────────┐ │ Driver (CPU) │ - │ dataloader ──► prompt-only batch │ + │ ckpt_manager.sync_weights() ──► sync LoRA to student sampler │ │ student_sampler.sample() ──► on-policy completions │ - │ teacher_sampler.sample(topk_prompt_logprobs=k) ──► teacher lps│ + │ teacher_sampler.sample(prompt_logprobs=k) ──► teacher lps │ │ student_model.forward_backward(teacher_output=...) ──► GKD │ └─────────────────────────────────────────────────────────────────┘ │ │ │ DataLoader vLLMSampler ×2 TransformersModel - (model GPUs) student + teacher (model GPUs) + student + teacher (model GPUs) Environment variables (all optional): - STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct) - TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-4B) + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B) MODEL_GPUS – GPUs for student model (default: 4) - SAMPLER_GPUS – GPUs for each vLLM sampler (default: 2) - MAX_NEW_TOKENS – max completion tokens (default: 512) - BATCH_SIZE – global prompt-level batch size (default: 8) - MAX_STEPS – total optimisation steps (default: 200) - LR – learning rate (default: 1e-4) + SAMPLER_GPUS – GPUs for each vLLM sampler (default: 4) + MAX_NEW_TOKENS – max completion tokens (default: 2048) + BATCH_SIZE – global prompt-level batch size (default: 16) + MAX_STEPS – total optimisation steps (default: 1000) + LR – learning rate (default: 5e-5) + N_SAMPLES – samples per prompt (default: 1) GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) GKD_TEMPERATURE – distillation temperature (default: 1.0) - GKD_TOPK – top-k vocab for teacher logprobs (default: 10) + GKD_TOPK – top-k vocab for teacher logprobs (default: 64) """ import os diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index e35dc648..d7d5df21 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -20,7 +20,7 @@ logger = get_logger() -MODEL_ID = 
os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-0.6B') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) From e952bde26200248bdc693d202d58e7e1a72c6e1a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 21:38:08 +0800 Subject: [PATCH 52/56] fix --- src/twinkle/model/megatron/megatron.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index eb5bed2a..ddddf41c 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -365,10 +365,6 @@ def calculate_loss(self, **kwargs): def backward(self, **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') - @remote_function(collect='first', lazy_collect=False) - def get_lr(self): - return self.optimizer_group['default']._get_lr() - @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_backward(self, *, From 4facc4f44e5852eac3df310f8e3630da61229fa2 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 18 Mar 2026 21:39:58 +0800 Subject: [PATCH 53/56] fix --- src/twinkle/model/megatron/multi_lora_megatron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 4a2e3fdb..b674ef46 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -12,10 +12,10 @@ from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation +from twinkle.infra import collect_tensor_dict from twinkle.loss import Loss from twinkle.metric import Metric from twinkle.processor import InputProcessor -from ...infra 
import collect_tensor_dict from ..multi_lora import MultiLora from .megatron import MegatronModel from .strategy import MegatronStrategy From 29db904565b497b9ea57a89e538718bda2018d87 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 19 Mar 2026 11:50:15 +0800 Subject: [PATCH 54/56] fix --- client_tools/client_generator.py | 2 +- cookbook/client/server/megatron/run.sh | 6 ++++++ cookbook/client/server/transformer/run.sh | 6 ++++++ cookbook/client/twinkle/self_host/grpo.py | 1 + src/twinkle/server/sampler/twinkle_handlers.py | 7 +++---- src/twinkle_client/sampler/vllm_sampler.py | 2 +- src/twinkle_client/types/__init__.py | 1 + src/twinkle_client/types/sampler.py | 9 +++++++-- 8 files changed, 26 insertions(+), 8 deletions(-) create mode 100644 cookbook/client/server/megatron/run.sh create mode 100644 cookbook/client/server/transformer/run.sh diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index c0df54d3..08700f51 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -795,7 +795,7 @@ def sample( json_data=json_data ) response.raise_for_status() - return SampleResponseModel(**response.json()) + return [SampleResponseModel(**r) for r in response.json()['samples']] def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse: """Set the template for encoding trajectories.""" diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh new file mode 100644 index 00000000..c8bf96d2 --- /dev/null +++ b/cookbook/client/server/megatron/run.sh @@ -0,0 +1,6 @@ +export RAY_ROTATION_MAX_BYTES=1024 +export RAY_ROTATION_BACKUP_COUNT=1 +CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false +CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4 +CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 +python server.py \ No newline at end of 
file diff --git a/cookbook/client/server/transformer/run.sh b/cookbook/client/server/transformer/run.sh new file mode 100644 index 00000000..c8bf96d2 --- /dev/null +++ b/cookbook/client/server/transformer/run.sh @@ -0,0 +1,6 @@ +export RAY_ROTATION_MAX_BYTES=1024 +export RAY_ROTATION_BACKUP_COUNT=1 +CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false +CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4 +CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 +python server.py \ No newline at end of file diff --git a/cookbook/client/twinkle/self_host/grpo.py b/cookbook/client/twinkle/self_host/grpo.py index bfc0f643..cabce6ea 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/grpo.py @@ -127,6 +127,7 @@ def train(): 'temperature': TEMPERATURE, 'top_p': 0.95, 'num_samples': NUM_GENERATIONS, + 'logprobs': 1, } # Track the current adapter path for sampling diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 83136f57..47dfd04b 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -59,10 +59,10 @@ def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> type """Health check / session creation endpoint.""" return types.CreateResponse() - @app.post('/twinkle/sample', response_model=types.SampleResponseModel) + @app.post('/twinkle/sample', response_model=types.SampleResponseModelList) def sample( request: Request, body: types.SampleRequest, - self: SamplerManagement = Depends(self_fn)) -> list[types.SampleResponseModel]: + self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModelList: """Sample completions from the model. Supports Trajectory or InputFeature inputs, with optional LoRA adapter. 
@@ -98,7 +98,6 @@ def sample( params = None if body.sampling_params: params = SamplingParams.from_dict(body.sampling_params) - params.num_samples = body.num_samples # Call sampler responses = self.sampler.sample( @@ -129,7 +128,7 @@ def sample( prompt_logprobs=response.prompt_logprobs, topk_prompt_logprobs=response.topk_prompt_logprobs, )) - return sample_models + return types.SampleResponseModelList(samples=sample_models) except Exception: logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=traceback.format_exc()) diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py index ea5e8767..13083a39 100644 --- a/src/twinkle_client/sampler/vllm_sampler.py +++ b/src/twinkle_client/sampler/vllm_sampler.py @@ -84,7 +84,7 @@ def sample( json_data=json_data ) response.raise_for_status() - return [SampleResponseModel(**r) for r in response.json()] + return [SampleResponseModel(**r) for r in response.json()['samples']] def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse: """Set the template for encoding trajectories.""" diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index b6650a28..59c88597 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -60,6 +60,7 @@ SampledSequenceModel, SampleRequest, SampleResponseModel, + SampleResponseModelList, SetTemplateRequest as SamplerSetTemplateRequest, SetTemplateResponse as SamplerSetTemplateResponse, ) diff --git a/src/twinkle_client/types/sampler.py b/src/twinkle_client/types/sampler.py index a1b579b7..f15c675a 100644 --- a/src/twinkle_client/types/sampler.py +++ b/src/twinkle_client/types/sampler.py @@ -24,20 +24,25 @@ class SampledSequenceModel(BaseModel): """A single sampled sequence, mirroring twinkle.data_format.SampledSequence.""" stop_reason: StopReason = Field(..., description="Stop reason: 'length' or 'stop'") tokens: List[int] 
= Field(..., description='Token IDs of the sampled sequence') - logprobs: Optional[List[float]] = Field(None, description='Per-token log-probabilities') + logprobs: Optional[List[Optional[List[Tuple[int, float]]]]] = Field(None, description='Per-token log-probabilities') decoded: Optional[str] = Field(None, description='Decoded text of the sampled sequence') new_input_feature: Optional[Dict[str, Any]] = Field( None, description='Updated InputFeature after sampling (input_ids, labels, etc.)') class SampleResponseModel(BaseModel): - """Response body for the /sample endpoint, mirroring twinkle.data_format.SampleResponse.""" + """Mirroring twinkle.data_format.SampleResponse.""" sequences: List[SampledSequenceModel] = Field( ..., description='List of sampled sequences') prompt_logprobs: Optional[List[Optional[float]]] = None topk_prompt_logprobs: Optional[List[Optional[List[Tuple[int, float]]]]] = None +class SampleResponseModelList(BaseModel): + """Response body for the /sample endpoint""" + samples: List[SampleResponseModel] = Field(..., description='List of sample responses') + + class SetTemplateRequest(BaseModel): """Request body for the /set_template endpoint.""" template_cls: str = Field(..., description="Template class name (e.g. 
'Template')") From 05be92f1ba9f7af7a8cc6c31366f1ca6c69beec6 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 19 Mar 2026 11:51:05 +0800 Subject: [PATCH 55/56] lint --- cookbook/client/server/megatron/run.sh | 2 +- cookbook/client/server/transformer/run.sh | 2 +- src/twinkle/server/sampler/twinkle_handlers.py | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh index c8bf96d2..38befef2 100644 --- a/cookbook/client/server/megatron/run.sh +++ b/cookbook/client/server/megatron/run.sh @@ -3,4 +3,4 @@ export RAY_ROTATION_BACKUP_COUNT=1 CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4 CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 -python server.py \ No newline at end of file +python server.py diff --git a/cookbook/client/server/transformer/run.sh b/cookbook/client/server/transformer/run.sh index c8bf96d2..38befef2 100644 --- a/cookbook/client/server/transformer/run.sh +++ b/cookbook/client/server/transformer/run.sh @@ -3,4 +3,4 @@ export RAY_ROTATION_BACKUP_COUNT=1 CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4 CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 -python server.py \ No newline at end of file +python server.py diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 47dfd04b..93a49b40 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -60,9 +60,8 @@ def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> type return types.CreateResponse() @app.post('/twinkle/sample', 
response_model=types.SampleResponseModelList) - def sample( - request: Request, body: types.SampleRequest, - self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModelList: + def sample(request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModelList: """Sample completions from the model. Supports Trajectory or InputFeature inputs, with optional LoRA adapter. From 9fdd627406eae6c724d19676b55d427c6686d21b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 19 Mar 2026 11:53:08 +0800 Subject: [PATCH 56/56] fix --- client_tools/client_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 08700f51..0f4957f4 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -768,7 +768,7 @@ def sample( adapter_name: str = '', adapter_uri: Optional[str] = None, num_samples: int = 1, - ) -> SampleResponseModel: + ) -> List[SampleResponseModel]: """Sample from the model. Args: