Skip to content

Commit 1cf2543

Browse files
Support GKD and on-policy distillation (#112)
1 parent cb52a6c commit 1cf2543

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+1437
-477
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ Or use ModelScope's [official image](https://www.modelscope.cn/docs/intro/enviro
101101

102102
## Changelog
103103

104+
- 🎉2026-03-19 Support GKD training, please refer to this [cookbook](cookbook/rl/gkd_on_policy.py).
104105
- 🎉2026-02-13 Initial version of Twinkle✨ released, including SFT/PT/RL support for text models.
105-
We also made available serverless training capabilities on [ModelScope](https://modelscope.cn) via
106-
Tinker-compatible APIs.
107106

108107
## Training as a Service on ModelScope
109108

README_ZH.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl
9191

9292
## 更新日志
9393

94+
🎉2026-03-19 支持GKD蒸馏能力,参考[cookbook](cookbook/rl/gkd_on_policy.py)。
9495
🎉2026-02-13 Twinkle✨ 初始版本发布,支持文本模型的SFT/PT/RL训练。我们还通过兼容Tinker的API,在魔搭社区上提供了无服务器训练功能。
9596

9697
## ModelScope 的训练服务

client_tools/client_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def sample(
768768
adapter_name: str = '',
769769
adapter_uri: Optional[str] = None,
770770
num_samples: int = 1,
771-
) -> SampleResponseModel:
771+
) -> List[SampleResponseModel]:
772772
"""Sample from the model.
773773
774774
Args:
@@ -795,7 +795,7 @@ def sample(
795795
json_data=json_data
796796
)
797797
response.raise_for_status()
798-
return SampleResponseModel(**response.json())
798+
return [SampleResponseModel(**r) for r in response.json()['samples']]
799799
800800
def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse:
801801
"""Set the template for encoding trajectories."""
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
export RAY_ROTATION_MAX_BYTES=1024
2+
export RAY_ROTATION_BACKUP_COUNT=1
3+
CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false
4+
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4
5+
CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
6+
python server.py
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
export RAY_ROTATION_MAX_BYTES=1024
2+
export RAY_ROTATION_BACKUP_COUNT=1
3+
CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false
4+
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4
5+
CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
6+
python server.py

cookbook/client/twinkle/self_host/grpo.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from twinkle_client.dataset import Dataset
3939
from twinkle_client.model import MultiLoraTransformersModel
4040
from twinkle_client.sampler import vLLMSampler
41-
from twinkle.preprocessor.llm import GSM8KProcessor
4241

4342
logger = get_logger()
4443

@@ -127,6 +126,8 @@ def train():
127126
'max_tokens': MAX_NEW_TOKENS,
128127
'temperature': TEMPERATURE,
129128
'top_p': 0.95,
129+
'num_samples': NUM_GENERATIONS,
130+
'logprobs': 1,
130131
}
131132

132133
# Track the current adapter path for sampling
@@ -153,21 +154,21 @@ def train():
153154
logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')
154155

155156
# ========== 2. Sample completions ==========
156-
sample_response = sampler.sample(
157+
sample_responses = sampler.sample(
157158
inputs=prompts,
158159
sampling_params=sampling_params,
159160
adapter_uri=current_adapter_uri,
160-
num_samples=NUM_GENERATIONS,
161161
)
162162

163163
all_input_data: List[Dict[str, Any]] = []
164164
all_old_logps: List[List[float]] = []
165165
all_completion_lengths: List[int] = []
166166

167-
for sequence in sample_response.sequences:
168-
all_input_data.append(sequence.new_input_feature)
169-
all_old_logps.append(sequence.logprobs)
170-
all_completion_lengths.append(len(sequence.tokens))
167+
for sample_response in sample_responses:
168+
for sequence in sample_response.sequences:
169+
all_input_data.append(sequence.new_input_feature)
170+
all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
171+
all_completion_lengths.append(len(sequence.tokens))
171172

172173
# ========== 3. Compute rewards ==========
173174

cookbook/client/twinkle/self_host/sample.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,29 +66,30 @@ def sample():
6666
sampling_params = {
6767
'max_tokens': 128,
6868
'temperature': 1.0,
69+
'num_samples': num_samples,
6970
}
7071

7172
# Step 7: Call the sampler
7273
# - inputs: list of Trajectory dicts (will be encoded server-side using the template)
7374
# - sampling_params: controls generation behavior
7475
# - adapter_uri: optional LoRA adapter path for fine-tuned inference
7576
# - num_samples: number of completions per prompt
76-
response = sampler.sample(
77+
responses = sampler.sample(
7778
inputs=[trajectory] * num_prompts,
7879
sampling_params=sampling_params,
7980
adapter_uri=ADAPTER_URI,
80-
num_samples=num_samples,
8181
)
8282

8383
# Step 8: Decode and print the results
8484
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
8585

86-
logger.info(f'Generated {len(response.sequences)} sequences '
87-
f'({num_prompts} prompts x {num_samples} samples)')
86+
for response in responses:
87+
logger.info(f'Generated {len(response.sequences)} sequences '
88+
f'({num_prompts} prompts x {num_samples} samples)')
8889

89-
for i, seq in enumerate(response.sequences):
90-
text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
91-
logger.info(f'Sequence {i}:\n {text}\n')
90+
for i, seq in enumerate(response.sequences):
91+
text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
92+
logger.info(f'Sequence {i}:\n {text}\n')
9293

9394

9495
if __name__ == '__main__':

cookbook/rl/gkd_off_policy.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""GKD Off-Policy Distillation via Ray.
2+
3+
Off-policy knowledge distillation: the student learns to match the teacher's
4+
token distribution on pre-existing reference responses from the dataset.
5+
6+
Pipeline:
7+
1. DataLoader supplies full-text batches (prompt + reference answer).
8+
2. Teacher vLLM sampler computes top-k prompt logprobs on the sequences.
9+
3. Student MegatronModel runs forward_backward() with GKDLoss.
10+
11+
Key difference from on-policy:
12+
- No student sampler needed (responses already in the dataset).
13+
- No weight sync needed (student doesn't sample).
14+
- Faster per-step (no generation latency), but less exploration.
15+
16+
Architecture (Ray):
17+
┌─────────────────────────────────────────────────────────────────┐
18+
│ Driver (CPU) │
19+
│ dataloader ──► full-text batch (prompt + reference answer) │
20+
│ teacher_sampler.sample(prompt_logprobs=k) ──► teacher lps │
21+
│ student_model.forward_backward(teacher_output=...) ──► GKD │
22+
└─────────────────────────────────────────────────────────────────┘
23+
24+
vLLMSampler + MegatronModel
25+
(teacher) (student)
26+
27+
Environment variables (all optional):
28+
STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B)
29+
TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B)
30+
MODEL_GPUS – GPUs for student model (default: 8)
31+
SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 8)
32+
BATCH_SIZE – global batch size (default: 16)
33+
MAX_STEPS – total optimisation steps (default: 1000)
34+
LR – learning rate (default: 5e-5)
35+
GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5)
36+
GKD_TEMPERATURE – distillation temperature (default: 1.0)
37+
GKD_TOPK – top-k vocab for teacher logprobs (default: 64)
38+
"""
39+
40+
import os
41+
from typing import List, Optional
42+
43+
import torch
44+
from peft import LoraConfig
45+
46+
import twinkle
47+
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
48+
from twinkle.data_format import SamplingParams
49+
from twinkle.dataloader import DataLoader
50+
from twinkle.dataset import Dataset, DatasetMeta
51+
from twinkle.loss import GKDLoss
52+
from twinkle.model import MegatronModel
53+
from twinkle.preprocessor import GSM8KProcessor
54+
from twinkle.sampler import vLLMSampler
55+
from twinkle.template import Template
56+
57+
logger = get_logger()
58+
59+
# ── Configuration ─────────────────────────────────────────────────────────────
60+
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B')
61+
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B')
62+
63+
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
64+
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8))
65+
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
66+
67+
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))
68+
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
69+
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
70+
71+
GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
72+
GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
73+
GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
74+
ADAPTER_NAME = 'default'
75+
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem step by step and put '
76+
'your final answer within #### <number>')
77+
78+
79+
# ── Dataset ───────────────────────────────────────────────────────────────────
80+
81+
def create_dataset():
    """Build the GSM8K full-text dataset (prompt + reference answer).

    Off-policy distillation trains on the dataset's own reference answers,
    which is why the processor is configured with ``add_assistant=True``.
    """
    ds = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
    ds.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
    ds.map(GSM8KProcessor(system=SYSTEM_PROMPT, add_assistant=True))
    return ds
87+
88+
89+
# ── Utility ───────────────────────────────────────────────────────────────────
90+
91+
def convert_topk_prompt_logprobs(
    topk_prompt_logprobs_batch: List[Optional[List[List[tuple]]]],
    topk: Optional[int] = None,
) -> dict:
    """Convert vLLM ``topk_prompt_logprobs`` to the GKDLoss teacher_output format.

    Args:
        topk_prompt_logprobs_batch: [batch] each entry is the topk_prompt_logprobs
            for one request. Shape: [seq_len, topk] per request, where each
            position is a List[(token_id, logprob)] — or ``None`` for the first
            prompt position, which has no context yet.
        topk: Number of top-k entries to keep per position. Defaults to the
            module-level ``GKD_TOPK`` when not given (kept as ``None`` in the
            signature so the default tracks the env-configured value).

    Returns:
        Dict with ``teacher_topk_logprobs`` (float32) and ``teacher_topk_indices``
        (long) tensors of shape [batch, max_seq_len, topk], rolled left by one
        position along the sequence axis so position ``t`` lines up with the
        label token at ``t + 1``.
    """
    if topk is None:
        topk = GKD_TOPK

    batch_logprobs = []
    batch_indices = []

    for seq_topk in topk_prompt_logprobs_batch:
        seq_logprobs = []
        seq_indices = []
        if seq_topk is not None:
            for pos_topk in seq_topk:
                if pos_topk is None:
                    # First prompt position carries no logprobs in vLLM's
                    # output; fill with a neutral placeholder row.
                    seq_logprobs.append([0.0] * topk)
                    seq_indices.append([0] * topk)
                else:
                    # vLLM may return topk+1 entries at a position when the
                    # actual token is not among the top-k. Truncate (and pad
                    # any short row) so every position has exactly `topk`
                    # entries — otherwise the nested lists are ragged and
                    # torch.tensor() below raises.
                    row_lp = [lp for _, lp in pos_topk][:topk]
                    row_idx = [tid for tid, _ in pos_topk][:topk]
                    if len(row_lp) < topk:
                        shortfall = topk - len(row_lp)
                        row_lp.extend([0.0] * shortfall)
                        row_idx.extend([0] * shortfall)
                    seq_logprobs.append(row_lp)
                    seq_indices.append(row_idx)
        batch_logprobs.append(seq_logprobs)
        batch_indices.append(seq_indices)

    # Pad every sequence to the same length so the batch is rectangular.
    max_len = max((len(seq) for seq in batch_logprobs), default=1)
    for i in range(len(batch_logprobs)):
        pad_len = max_len - len(batch_logprobs[i])
        if pad_len > 0:
            batch_logprobs[i].extend([[0.0] * topk for _ in range(pad_len)])
            batch_indices[i].extend([[0] * topk for _ in range(pad_len)])

    # Roll left by one to align with labels (the first position has no valid
    # logprobs, so its placeholder wraps to the end where it is ignored).
    return {
        'teacher_topk_logprobs': torch.roll(torch.tensor(batch_logprobs, dtype=torch.float32), shifts=-1, dims=1),
        'teacher_topk_indices': torch.roll(torch.tensor(batch_indices, dtype=torch.long), shifts=-1, dims=1),
    }
135+
136+
137+
# ── Training ──────────────────────────────────────────────────────────────────
138+
139+
def main():
    """Run GKD off-policy distillation: teacher top-k logprobs -> student GKD loss.

    Allocates two Ray device groups (student trainer, teacher vLLM sampler),
    then loops over full-text batches: the teacher scores the reference
    sequences with top-k prompt logprobs, and the student runs
    forward_backward() against them with GKDLoss.
    """
    # One GPU group per role; group names are referenced below via remote_group.
    device_groups = [
        DeviceGroup(name='student_model', ranks=MODEL_GPUS, device_type='cuda'),
        DeviceGroup(name='teacher_sampler', ranks=SAMPLER_GPUS, device_type='cuda'),
    ]
    # Pure data parallelism on both meshes (dp_size == world_size).
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)

    twinkle.initialize(
        mode='ray',
        nproc_per_node=NUM_GPUS,
        groups=device_groups,
    )
    logger.info(get_device_placement())

    # ── Student model (trainable) ──────────────────────────────────────────────
    student_model = MegatronModel(
        model_id=STUDENT_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='student_model',
    )
    # Train a LoRA adapter rather than full weights.
    student_model.add_adapter_to_model(
        ADAPTER_NAME,
        LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
    )
    student_model.set_optimizer('default', lr=LEARNING_RATE)
    student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
    student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
    student_model.set_template('Template', model_id=STUDENT_MODEL_ID)

    # ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────
    # NOTE(review): max_logprobs is hard-coded to 64 while GKD_TOPK is
    # env-configurable — a GKD_TOPK above 64 would exceed the engine limit;
    # confirm these are meant to stay in sync.
    teacher_sampler = vLLMSampler(
        model_id=TEACHER_MODEL_ID,
        engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64},
        device_mesh=sampler_mesh,
        remote_group='teacher_sampler',
    )
    # NOTE(review): the student passes the template as the string 'Template'
    # while the teacher passes the Template class — verify both forms are
    # accepted, or make them consistent.
    teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID)

    # ── DataLoader (full-text: prompt + reference answer) ──────────────────────
    dataloader = DataLoader(
        dataset=create_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        remote_group='student_model',
    )

    logger.info(f'GKD Off-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}')
    logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}')

    optim_step = 0
    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        # Teacher vLLM computes top-k prompt logprobs on the reference sequences
        # max_tokens=0: don't generate new content, just compute logprobs on input
        teacher_response = teacher_sampler.sample(
            batch,
            SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK, num_samples=1),
        )

        # Flatten: one input feature per sequence across all responses.
        input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences]

        # Convert teacher logprobs to tensor format for GKDLoss
        teacher_output = convert_topk_prompt_logprobs(
            [resp.topk_prompt_logprobs for resp in teacher_response],
        )

        # Student forward + GKD backward
        student_model.forward_backward(inputs=input_data, **teacher_output)
        student_model.clip_grad_and_step()

        # Periodic metric logging (skipped at step 0).
        if optim_step > 0 and optim_step % 10 == 0:
            metric = student_model.calculate_metric(is_training=True)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

        # Periodic checkpointing every 50 optimisation steps.
        if optim_step > 0 and optim_step % 50 == 0:
            student_model.save(f'gkd-offpolicy-ckpt-{optim_step}')

        optim_step += 1

    student_model.save('gkd-offpolicy-final')
    logger.info('GKD off-policy training completed.')


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)