|
| 1 | +"""GKD Off-Policy Distillation via Ray. |
| 2 | +
|
| 3 | +Off-policy knowledge distillation: the student learns to match the teacher's |
| 4 | +token distribution on pre-existing reference responses from the dataset. |
| 5 | +
|
| 6 | +Pipeline: |
| 7 | + 1. DataLoader supplies full-text batches (prompt + reference answer). |
| 8 | + 2. Teacher vLLM sampler computes top-k prompt logprobs on the sequences. |
| 9 | + 3. Student MegatronModel runs forward_backward() with GKDLoss. |
| 10 | +
|
| 11 | +Key difference from on-policy: |
| 12 | + - No student sampler needed (responses already in the dataset). |
| 13 | + - No weight sync needed (student doesn't sample). |
| 14 | + - Faster per-step (no generation latency), but less exploration. |
| 15 | +
|
| 16 | +Architecture (Ray): |
| 17 | + ┌─────────────────────────────────────────────────────────────────┐ |
| 18 | + │ Driver (CPU) │ |
| 19 | + │ dataloader ──► full-text batch (prompt + reference answer) │ |
| 20 | + │ teacher_sampler.sample(prompt_logprobs=k) ──► teacher lps │ |
| 21 | + │ student_model.forward_backward(teacher_output=...) ──► GKD │ |
| 22 | + └─────────────────────────────────────────────────────────────────┘ |
| 23 | + │ |
| 24 | + vLLMSampler + MegatronModel |
| 25 | + (teacher) (student) |
| 26 | +
|
| 27 | +Environment variables (all optional): |
| 28 | + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B) |
| 29 | + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B) |
| 30 | + MODEL_GPUS – GPUs for student model (default: 8) |
| 31 | + SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 8) |
| 32 | + BATCH_SIZE – global batch size (default: 16) |
| 33 | + MAX_STEPS – total optimisation steps (default: 1000) |
| 34 | + LR – learning rate (default: 5e-5) |
| 35 | + GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) |
| 36 | + GKD_TEMPERATURE – distillation temperature (default: 1.0) |
| 37 | + GKD_TOPK – top-k vocab for teacher logprobs (default: 64) |
| 38 | +""" |
| 39 | + |
| 40 | +import os |
| 41 | +from typing import List, Optional |
| 42 | + |
| 43 | +import torch |
| 44 | +from peft import LoraConfig |
| 45 | + |
| 46 | +import twinkle |
| 47 | +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger |
| 48 | +from twinkle.data_format import SamplingParams |
| 49 | +from twinkle.dataloader import DataLoader |
| 50 | +from twinkle.dataset import Dataset, DatasetMeta |
| 51 | +from twinkle.loss import GKDLoss |
| 52 | +from twinkle.model import MegatronModel |
| 53 | +from twinkle.preprocessor import GSM8KProcessor |
| 54 | +from twinkle.sampler import vLLMSampler |
| 55 | +from twinkle.template import Template |
| 56 | + |
| 57 | +logger = get_logger() |
| 58 | + |
# ── Configuration ─────────────────────────────────────────────────────────────
# Student/teacher checkpoints; the 'ms://' prefix presumably resolves to a
# ModelScope hub ID — TODO confirm against twinkle's model resolver.
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B')
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B')

# GPU split: student training ranks and teacher inference ranks are disjoint
# Ray device groups; the total is what twinkle.initialize() provisions.
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

# Optimisation hyper-parameters (overridable via environment, see module doc).
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))

# GKD loss knobs: beta interpolates forward/reverse KL (per module doc),
# temperature softens distributions, topk bounds the teacher vocab slice
# carried per token position.
GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
ADAPTER_NAME = 'default'
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem step by step and put '
                 'your final answer within #### <number>')
| 77 | + |
| 78 | + |
| 79 | +# ── Dataset ─────────────────────────────────────────────────────────────────── |
| 80 | + |
def create_dataset():
    """Build the GSM8K training set for off-policy distillation.

    Each example carries the full text (prompt plus the pre-existing reference
    answer), so the teacher can score existing responses rather than sampled
    ones.
    """
    meta = DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')
    gsm8k = Dataset(meta)
    gsm8k.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
    gsm8k.map(GSM8KProcessor(system=SYSTEM_PROMPT, add_assistant=True))
    return gsm8k
| 87 | + |
| 88 | + |
| 89 | +# ── Utility ─────────────────────────────────────────────────────────────────── |
| 90 | + |
| 91 | +def convert_topk_prompt_logprobs( |
| 92 | + topk_prompt_logprobs_batch: List[Optional[List[List[tuple]]]], |
| 93 | +) -> dict: |
| 94 | + """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format. |
| 95 | +
|
| 96 | + Args: |
| 97 | + topk_prompt_logprobs_batch: [batch] each is topk_prompt_logprobs for one request. |
| 98 | + Shape: [seq_len, topk] per request, where each position is List[(token_id, logprob)]. |
| 99 | +
|
| 100 | + Returns: |
| 101 | + Dict with teacher logprobs/indices tensors. |
| 102 | + """ |
| 103 | + batch_logprobs = [] |
| 104 | + batch_indices = [] |
| 105 | + |
| 106 | + for seq_topk in topk_prompt_logprobs_batch: |
| 107 | + seq_logprobs = [] |
| 108 | + seq_indices = [] |
| 109 | + if seq_topk is not None: |
| 110 | + for pos_topk in seq_topk: |
| 111 | + if pos_topk is None: |
| 112 | + # First position is None, fill with placeholder |
| 113 | + seq_logprobs.append([0.0] * GKD_TOPK) |
| 114 | + seq_indices.append([0] * GKD_TOPK) |
| 115 | + else: |
| 116 | + seq_logprobs.append([lp for _, lp in pos_topk]) |
| 117 | + seq_indices.append([tid for tid, _ in pos_topk]) |
| 118 | + batch_logprobs.append(seq_logprobs) |
| 119 | + batch_indices.append(seq_indices) |
| 120 | + |
| 121 | + # Pad to same seq_len within batch |
| 122 | + max_len = max(len(seq) for seq in batch_logprobs) if batch_logprobs else 1 |
| 123 | + |
| 124 | + for i in range(len(batch_logprobs)): |
| 125 | + pad_len = max_len - len(batch_logprobs[i]) |
| 126 | + if pad_len > 0: |
| 127 | + batch_logprobs[i].extend([[0.0] * GKD_TOPK] * pad_len) |
| 128 | + batch_indices[i].extend([[0] * GKD_TOPK] * pad_len) |
| 129 | + |
| 130 | + # Roll to align with labels (first position has no valid logprobs) |
| 131 | + return { |
| 132 | + 'teacher_topk_logprobs': torch.roll(torch.tensor(batch_logprobs, dtype=torch.float32), shifts=-1, dims=1), |
| 133 | + 'teacher_topk_indices': torch.roll(torch.tensor(batch_indices, dtype=torch.long), shifts=-1, dims=1), |
| 134 | + } |
| 135 | + |
| 136 | + |
| 137 | +# ── Training ────────────────────────────────────────────────────────────────── |
| 138 | + |
def main() -> None:
    """Run the off-policy GKD job: teacher scores dataset answers, student fits them.

    Topology: two disjoint Ray device groups — MODEL_GPUS ranks train the
    LoRA-wrapped student, SAMPLER_GPUS ranks host the frozen teacher in vLLM.
    No weight sync loop exists because the student never generates.
    """
    device_groups = [
        DeviceGroup(name='student_model', ranks=MODEL_GPUS, device_type='cuda'),
        DeviceGroup(name='teacher_sampler', ranks=SAMPLER_GPUS, device_type='cuda'),
    ]
    # Pure data parallelism on both meshes (dp_size == world_size, no TP/PP).
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)

    twinkle.initialize(
        mode='ray',
        nproc_per_node=NUM_GPUS,
        groups=device_groups,
    )
    logger.info(get_device_placement())

    # ── Student model (trainable) ──────────────────────────────────────────────
    student_model = MegatronModel(
        model_id=STUDENT_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='student_model',
    )
    # Only the LoRA adapter is optimised; the base weights stay frozen.
    student_model.add_adapter_to_model(
        ADAPTER_NAME,
        LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
    )
    student_model.set_optimizer('default', lr=LEARNING_RATE)
    student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
    student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
    student_model.set_template('Template', model_id=STUDENT_MODEL_ID)

    # ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────
    teacher_sampler = vLLMSampler(
        model_id=TEACHER_MODEL_ID,
        # NOTE(review): max_logprobs is hard-coded to 64 while GKD_TOPK is
        # env-configurable — requests will presumably fail if GKD_TOPK > 64;
        # consider max(GKD_TOPK, 64). Confirm against vLLM engine args.
        engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64},
        device_mesh=sampler_mesh,
        remote_group='teacher_sampler',
    )
    # NOTE(review): passes the Template class here, but the string 'Template'
    # is used everywhere else in this file — confirm set_template accepts both.
    teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID)

    # ── DataLoader (full-text: prompt + reference answer) ──────────────────────
    dataloader = DataLoader(
        dataset=create_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        remote_group='student_model',
    )

    logger.info(f'GKD Off-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}')
    logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}')

    optim_step = 0
    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        # Teacher vLLM computes top-k prompt logprobs on the reference sequences.
        # max_tokens=0: don't generate new content, just compute logprobs on input.
        teacher_response = teacher_sampler.sample(
            batch,
            SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK, num_samples=1),
        )

        # Flatten per-response sequences into one list of student inputs.
        # NOTE(review): teacher_output below is built per *response* while this
        # is per *sequence* — only aligned because num_samples=1; revisit if
        # that ever changes.
        input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences]

        # Convert teacher logprobs to tensor format for GKDLoss.
        teacher_output = convert_topk_prompt_logprobs(
            [resp.topk_prompt_logprobs for resp in teacher_response],
        )

        # Student forward + GKD backward, then gradient clip + optimizer step.
        student_model.forward_backward(inputs=input_data, **teacher_output)
        student_model.clip_grad_and_step()

        # Periodic metric logging (every 10 steps, skipping step 0).
        if optim_step > 0 and optim_step % 10 == 0:
            metric = student_model.calculate_metric(is_training=True)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

        # Periodic checkpointing (every 50 steps, skipping step 0).
        if optim_step > 0 and optim_step % 50 == 0:
            student_model.save(f'gkd-offpolicy-ckpt-{optim_step}')

        optim_step += 1

    student_model.save('gkd-offpolicy-final')
    logger.info('GKD off-policy training completed.')
| 223 | + |
| 224 | + |
# Script entry point: run the full distillation job.
if __name__ == '__main__':
    main()
0 commit comments